diff --git a/batch_test.go b/batch_test.go index 1c37093a..186dcfd4 100644 --- a/batch_test.go +++ b/batch_test.go @@ -215,122 +215,6 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null -);` - mustExec(t, conn, sql) - - batch := conn.BeginBatch() - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q1", 1}, - []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - - ctx, cancelFn := context.WithCancel(context.Background()) - - err := batch.Send(ctx) - if err != nil { - t.Fatal(err) - } - - cancelFn() - - _, err = batch.ExecResults() - if err != context.Canceled { - t.Errorf("err => %v, want %v", err, context.Canceled) - } - - ensureConnValid(t, conn) -} - -func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - - batch := conn.BeginBatch() - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - - ctx, cancelFn := context.WithCancel(context.Background()) - - err := batch.Send(ctx) - if err != nil { - t.Fatal(err) - } - - cancelFn() - - rows, err := batch.QueryResults() - - if rows.Next() { - t.Error("unexpected row") - } - - if rows.Err() != context.Canceled { - t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) - } - - batch.Close() - - ensureConnValid(t, conn) -} - -func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - - batch := conn.BeginBatch() - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - batch.Queue("select pg_sleep(2)", - nil, - nil, - nil, - ) - - ctx, cancelFn := context.WithCancel(context.Background()) - - err := batch.Send(ctx) - if err != nil { - t.Fatal(err) - } - - cancelFn() - - err = batch.Close() - if err != context.Canceled { - t.Errorf("err => %v, want %v", err, context.Canceled) - } - - ensureConnValid(t, conn) -} - func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() diff --git a/conn_test.go b/conn_test.go index 40074456..35b0f414 100644 --- a/conn_test.go +++ b/conn_test.go @@ -241,25 +241,6 @@ func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { } } -func TestExecContextCancelationCancelsQuery(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - _, err := conn.Exec(ctx, "select pg_sleep(1)") - cancel() - if err != context.DeadlineExceeded { - t.Fatalf("Expected context.DeadlineExceeded err, got %v", err) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - - ensureConnValid(t, conn) -} - func TestExecFailureCloseBefore(t *testing.T) { t.Parallel() diff --git a/query_test.go b/query_test.go index 2d638784..1ddbda57 100644 --- a/query_test.go +++ b/query_test.go @@ -1258,37 +1258,6 @@ func TestQueryExContextErrorWhileReceivingRows(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryExContextCancelationCancelsQuery(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(500 * time.Millisecond) - cancelFunc() - }() - - rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil) - if err != nil { - t.Fatal(err) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - - for rows.Next() { - t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") - } - - if rows.Err() != context.Canceled { - t.Fatalf("Expected context.Canceled error, got %v", rows.Err()) - } - - ensureConnValid(t, conn) -} - func TestQueryRowExContextSuccess(t *testing.T) { t.Parallel() @@ -1331,30 +1300,6 @@ func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - go func() { - time.Sleep(500 * time.Millisecond) - cancelFunc() - }() - - var result []byte - err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result) - if err != context.Canceled { - t.Fatalf("Expected context.Canceled error, got %v", err) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - - ensureConnValid(t, conn) -} - func TestConnQueryRowExSingleRoundTrip(t *testing.T) { t.Parallel() diff --git a/tx_test.go b/tx_test.go index 4b6142fe..4865a084 100644 --- a/tx_test.go +++ b/tx_test.go @@ -2,15 +2,11 @@ package pgx_test import ( "context" - "fmt" "os" "testing" - "time" "github.com/jackc/pgconn" - "github.com/jackc/pgproto3" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgmock" ) func TestTransactionSuccessfulCommit(t *testing.T) { @@ -233,107 +229,6 @@ func TestBeginExReadOnly(t *testing.T) { } } -func TestConnBeginExContextCancel(t *testing.T) { - t.Parallel() - - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), - pgmock.WaitForClose(), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - errChan := make(chan error, 1) - go func() { - errChan <- server.ServeOne() - }() - - pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatal(err) - } - - conn := mustConnect(t, pgx.ConnConfig{Config: *pc}) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - _, err = conn.BeginEx(ctx, nil) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) - } - - if conn.IsAlive() { - t.Error("expected conn to be dead after BeginEx failure") - } - - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } -} - -func TestTxCommitExCancel(t *testing.T) { - t.Parallel() - - script := &pgmock.Script{ - Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), - } - script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) - script.Steps = append(script.Steps, - pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), - pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), - pgmock.WaitForClose(), - ) - - server, err := pgmock.NewServer(script) - if err != nil { - t.Fatal(err) - } - defer server.Close() - - errChan := make(chan error, 1) - go func() { - errChan <- server.ServeOne() - }() - - pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) - if err != nil { - t.Fatal(err) - } - - conn := mustConnect(t, pgx.ConnConfig{Config: *pc}) - defer conn.Close() - - tx, err := conn.Begin() - if err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - err = tx.CommitEx(ctx) - if err != context.DeadlineExceeded { - t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) - } - - if conn.IsAlive() { - t.Error("expected conn to be dead after CommitEx failure") - } - - if err := <-errChan; err != nil { - t.Errorf("mock server err: %v", err) - } -} - func TestTxStatus(t *testing.T) { t.Parallel()