diff --git a/batch_test.go b/batch_test.go index c93c3dfa..fd65985e 100644 --- a/batch_test.go +++ b/batch_test.go @@ -484,7 +484,7 @@ func TestTxBeginBatch(t *testing.T) { );` mustExec(t, conn, sql) - tx, _ := conn.Begin() + tx, _ := conn.Begin(context.Background(), nil) batch := tx.BeginBatch() batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, @@ -563,7 +563,7 @@ func TestTxBeginBatchRollback(t *testing.T) { );` mustExec(t, conn, sql) - tx, _ := conn.Begin() + tx, _ := conn.Begin(context.Background(), nil) batch := tx.BeginBatch() batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, diff --git a/bench_test.go b/bench_test.go index 39f3f324..48433ff3 100644 --- a/bench_test.go +++ b/bench_test.go @@ -349,7 +349,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { b.Fatal(err) } @@ -388,7 +388,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc } resetQuery() - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { return 0, err } diff --git a/large_objects_test.go b/large_objects_test.go index 856ac397..332699ec 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -19,7 +19,7 @@ func TestLargeObjects(t *testing.T) { t.Fatal(err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -133,7 +133,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Fatal(err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { rows.Close() // Start a new transaction - tx2, err := conn.Begin() + tx2, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/pool/conn.go b/pool/conn.go index 616f1f40..68dbc299 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -61,8 +61,8 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.Conn().QueryRow(ctx, sql, args...) } -func (c *Conn) Begin() (*pgx.Tx, error) { - return c.Conn().Begin() +func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) { + return c.Conn().Begin(ctx, txOptions) } func (c *Conn) Conn() *pgx.Conn { diff --git a/pool/pool.go b/pool/pool.go index 24587705..ed459735 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -127,13 +127,13 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return &poolRow{r: row, c: c} } -func (p *Pool) Begin() (*Tx, error) { - c, err := p.Acquire(context.Background()) +func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) { + c, err := p.Acquire(ctx) if err != nil { return nil, err } - t, err := c.Begin() + t, err := c.Begin(ctx, txOptions) if err != nil { return nil, err } diff --git a/pool/tx_test.go b/pool/tx_test.go index 7195dae6..3ec4a0ce 100644 --- a/pool/tx_test.go +++ b/pool/tx_test.go @@ -14,7 +14,7 @@ func TestTxExec(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin() + tx, err := pool.Begin(context.Background(), nil) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -26,7 +26,7 @@ func TestTxQuery(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin() + tx, err := pool.Begin(context.Background(), nil) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -38,7 +38,7 @@ func TestTxQueryRow(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin() + tx, err := pool.Begin(context.Background(), nil) require.NoError(t, err) defer tx.Rollback(context.Background()) diff --git a/stdlib/sql.go b/stdlib/sql.go index 5a49ce0b..0cfafe2c 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -210,7 +210,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.AccessMode = pgx.ReadOnly } - tx, err := c.conn.BeginEx(ctx, &pgxOpts) + tx, err := c.conn.Begin(ctx, &pgxOpts) if err != nil { return nil, err } diff --git a/tx.go b/tx.go index e840c741..b28603b5 100644 --- a/tx.go +++ b/tx.go @@ -78,16 +78,9 @@ var ErrTxInFailure = errors.New("tx failed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// Begin starts a transaction with the default transaction mode for the -// current connection. To use a specific transaction mode see BeginEx. -func (c *Conn) Begin() (*Tx, error) { - return c.BeginEx(context.Background(), nil) -} - -// BeginEx starts a transaction with txOptions determining the transaction -// mode. Unlike database/sql, the context only affects the begin command. i.e. -// there is no auto-rollback on context cancelation. -func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { +// BeginEx starts a transaction with txOptions determining the transaction mode. txOptions can be nil. Unlike +// database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancelation. +func (c *Conn) Begin(ctx context.Context, txOptions *TxOptions) (*Tx, error) { _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or diff --git a/tx_test.go b/tx_test.go index 22ca14c3..2a81f629 100644 --- a/tx_test.go +++ b/tx_test.go @@ -26,7 +26,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -68,7 +68,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -113,15 +113,15 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer c1.Exec(context.Background(), `drop table tx_serializable_sums`) - tx1, err := c1.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := c1.Begin(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginEx failed: %v", err) + t.Fatalf("Begin failed: %v", err) } defer tx1.Rollback(context.Background()) - tx2, err := c2.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := c2.Begin(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginEx failed: %v", err) + t.Fatalf("Begin failed: %v", err) } defer tx2.Rollback(context.Background()) @@ -163,7 +163,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -188,7 +188,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) { } } -func TestBeginExIsoLevels(t *testing.T) { +func TestBeginIsoLevels(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -196,9 +196,9 @@ func TestBeginExIsoLevels(t *testing.T) { isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso}) + tx, err := conn.Begin(context.Background(), &pgx.TxOptions{IsoLevel: iso}) if err != nil { - t.Fatalf("conn.BeginEx failed: %v", err) + t.Fatalf("conn.Begin failed: %v", err) } var level pgx.TxIsoLevel @@ -214,15 +214,15 @@ func TestBeginExIsoLevels(t *testing.T) { } } -func TestBeginExReadOnly(t *testing.T) { +func TestBeginReadOnly(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) + tx, err := conn.Begin(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { - t.Fatalf("conn.BeginEx failed: %v", err) + t.Fatalf("conn.Begin failed: %v", err) } defer tx.Rollback(context.Background()) @@ -238,7 +238,7 @@ func TestTxStatus(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -262,7 +262,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestTxErr(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin() + tx, err := conn.Begin(context.Background(), nil) if err != nil { t.Fatal(err) }