diff --git a/tx_test.go b/tx_test.go index a6483941..854038c4 100644 --- a/tx_test.go +++ b/tx_test.go @@ -5,6 +5,7 @@ import ( "errors" "os" "testing" + "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" @@ -157,41 +158,44 @@ func TestTxCommitSerializationFailure(t *testing.T) { c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, c2) - c1.Exec(context.Background(), `drop table if exists tx_serializable_sums`) - _, err := c1.Exec(context.Background(), `create table tx_serializable_sums(num integer);`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + c1.Exec(ctx, `drop table if exists tx_serializable_sums`) + _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`) if err != nil { t.Fatalf("Unable to create temporary table: %v", err) } - defer c1.Exec(context.Background(), `drop table tx_serializable_sums`) + defer c1.Exec(ctx, `drop table tx_serializable_sums`) - tx1, err := c1.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } - defer tx1.Rollback(context.Background()) + defer tx1.Rollback(ctx) - tx2, err := c2.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } - defer tx2.Rollback(context.Background()) + defer tx2.Rollback(ctx) - _, err = tx1.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) + _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } - _, err = tx2.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) + _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } - err = tx1.Commit(context.Background()) + err = tx1.Commit(ctx) if err != nil { t.Fatalf("Commit failed: %v", err) } - err = tx2.Commit(context.Background()) + err = tx2.Commit(ctx) if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { t.Fatalf("Expected serialization error 40001, got %#v", err) }