From c3e41872a81e9ece6fd528c25a9b3fcab847ccf9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Aug 2019 15:53:55 -0500 Subject: [PATCH] Resplit Begin and BeginEx This is in preparation for a Begin / Tx interface that will similate nested transactions with savepoints. In addition, this passes the TxOptions struct by value and thereby removes an allocation. --- batch_test.go | 4 ++-- bench_test.go | 4 ++-- large_objects_test.go | 6 +++--- pgxpool/conn.go | 8 ++++++-- pgxpool/pool.go | 7 +++++-- pgxpool/tx_test.go | 10 +++++----- stdlib/sql.go | 2 +- tx.go | 18 ++++++++++-------- tx_test.go | 20 ++++++++++---------- 9 files changed, 44 insertions(+), 35 deletions(-) diff --git a/batch_test.go b/batch_test.go index dd65cb9a..82f692ba 100644 --- a/batch_test.go +++ b/batch_test.go @@ -463,7 +463,7 @@ func TestTxSendBatch(t *testing.T) { );` mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background(), nil) + tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, @@ -538,7 +538,7 @@ func TestTxSendBatchRollback(t *testing.T) { );` mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background(), nil) + tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, diff --git a/bench_test.go b/bench_test.go index 6ff2e239..a85de09d 100644 --- a/bench_test.go +++ b/bench_test.go @@ -413,7 +413,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { b.Fatal(err) } @@ -452,7 +452,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc } resetQuery() - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { return 0, err } diff --git a/large_objects_test.go b/large_objects_test.go index 2d8dd9c4..f2fa6016 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -22,7 +22,7 @@ func TestLargeObjects(t *testing.T) { t.Fatal(err) } - tx, err := conn.Begin(ctx, nil) + tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } @@ -135,7 +135,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Fatal(err) } - tx, err := conn.Begin(ctx, nil) + tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } @@ -175,7 +175,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { rows.Close() // Start a new transaction - tx2, err := conn.Begin(ctx, nil) + tx2, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } diff --git a/pgxpool/conn.go b/pgxpool/conn.go index a1adb2e6..93d77044 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -66,8 +66,12 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNam return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) } -func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) { - return c.Conn().Begin(ctx, txOptions) +func (c *Conn) Begin(ctx context.Context) (*pgx.Tx, error) { + return c.Conn().Begin(ctx) +} + +func (c *Conn) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (*pgx.Tx, error) { + return c.Conn().BeginEx(ctx, txOptions) } func (c *Conn) Conn() *pgx.Conn { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index a5bf6873..6ba12568 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -352,13 +352,16 @@ func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return &poolBatchResults{br: br, c: c} } -func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) { +func (p *Pool) Begin(ctx context.Context) (*Tx, error) { + return p.BeginEx(ctx, pgx.TxOptions{}) +} +func (p *Pool) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (*Tx, error) { c, err := p.Acquire(ctx) if err != nil { return nil, err } - t, err := c.Begin(ctx, txOptions) + t, err := c.BeginEx(ctx, txOptions) if err != nil { return nil, err } diff --git a/pgxpool/tx_test.go b/pgxpool/tx_test.go index adfa3739..d66ad338 100644 --- a/pgxpool/tx_test.go +++ b/pgxpool/tx_test.go @@ -16,7 +16,7 @@ func TestTxExec(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background(), nil) + tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -30,7 +30,7 @@ func TestTxQuery(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background(), nil) + tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -44,7 +44,7 @@ func TestTxQueryRow(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background(), nil) + tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -58,7 +58,7 @@ func TestTxSendBatch(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background(), nil) + tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) @@ -72,7 +72,7 @@ func TestTxCopyFrom(t *testing.T) { require.NoError(t, err) defer pool.Close() - tx, err := pool.Begin(context.Background(), nil) + tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) diff --git a/stdlib/sql.go b/stdlib/sql.go index 78fac3cf..5f7c2690 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -209,7 +209,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.AccessMode = pgx.ReadOnly } - tx, err := c.conn.Begin(ctx, &pgxOpts) + tx, err := c.conn.BeginEx(ctx, pgxOpts) if err != nil { return nil, err } diff --git a/tx.go b/tx.go index 10cabe6e..53effd66 100644 --- a/tx.go +++ b/tx.go @@ -50,11 +50,7 @@ type TxOptions struct { DeferrableMode TxDeferrableMode } -func (txOptions *TxOptions) beginSQL() string { - if txOptions == nil { - return "begin" - } - +func (txOptions TxOptions) beginSQL() string { buf := &bytes.Buffer{} buf.WriteString("begin") if txOptions.IsoLevel != "" { @@ -78,9 +74,15 @@ var ErrTxInFailure = errors.New("tx failed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// BeginEx starts a transaction with txOptions determining the transaction mode. txOptions can be nil. Unlike -// database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancelation. -func (c *Conn) Begin(ctx context.Context, txOptions *TxOptions) (*Tx, error) { +// Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// auto-rollback on context cancelation. +func (c *Conn) Begin(ctx context.Context) (*Tx, error) { + return c.BeginEx(ctx, TxOptions{}) +} + +// BeginEx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only +// affects the begin command. i.e. there is no auto-rollback on context cancelation. +func (c *Conn) BeginEx(ctx context.Context, txOptions TxOptions) (*Tx, error) { _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or diff --git a/tx_test.go b/tx_test.go index 2a81f629..93edb048 100644 --- a/tx_test.go +++ b/tx_test.go @@ -26,7 +26,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -68,7 +68,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -113,13 +113,13 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer c1.Exec(context.Background(), `drop table tx_serializable_sums`) - tx1, err := c1.Begin(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := c1.BeginEx(context.Background(), pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } defer tx1.Rollback(context.Background()) - tx2, err := c2.Begin(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := c2.BeginEx(context.Background(), pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } @@ -163,7 +163,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) { t.Fatalf("Failed to create table: %v", err) } - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -196,7 +196,7 @@ func TestBeginIsoLevels(t *testing.T) { isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.Begin(context.Background(), &pgx.TxOptions{IsoLevel: iso}) + tx, err := conn.BeginEx(context.Background(), pgx.TxOptions{IsoLevel: iso}) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -220,7 +220,7 @@ func TestBeginReadOnly(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) + tx, err := conn.BeginEx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } @@ -238,7 +238,7 @@ func TestTxStatus(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } @@ -262,7 +262,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestTxErr(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - tx, err := conn.Begin(context.Background(), nil) + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) }