diff --git a/CHANGELOG.md b/CHANGELOG.md index 4438e0bc..4e97434f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -141,6 +141,11 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent * `CollectOneRow` collects one row using `RowTo*` functions. * `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. +## Tx Helpers + +Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and +`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. + ## SendBatch Uses Pipeline Mode When Appropriate Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 diff --git a/doc.go b/doc.go index 48971110..0fd3713b 100644 --- a/doc.go +++ b/doc.go @@ -247,10 +247,10 @@ These are internally implemented with savepoints. Use BeginTx to control the transaction mode. -BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the +BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") return err }) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index b8711da9..b9ff29dc 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -92,14 +92,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return c.Conn().BeginTx(ctx, txOptions) } -func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return c.Conn().BeginFunc(ctx, f) -} - -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - return c.Conn().BeginTxFunc(ctx, txOptions, f) -} - func (c *Conn) Ping(ctx context.Context) error { return c.Conn().Ping(ctx) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index d73b93fb..7027e282 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -570,20 +570,6 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return &Tx{t: t, c: c}, err } -func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return p.BeginTxFunc(ctx, pgx.TxOptions{}, f) -} - -func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { - c, err := p.Acquire(ctx) - if err != nil { - return err - } - defer c.Release() - - return c.BeginTxFunc(ctx, txOptions, f) -} - func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 3e3058d2..5cd943d7 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -806,15 +806,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") require.NoError(t, err) return nil @@ -853,11 +853,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { db.Exec(context.Background(), "drop table pgxpooltx") }() - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 3ddb742c..74df8593 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -18,10 +18,6 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } -func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { - return tx.t.BeginFunc(ctx, f) -} - // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. diff --git a/tx.go b/tx.go index 2a05b70d..24daf0f8 100644 --- a/tx.go +++ b/tx.go @@ -94,39 +94,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { return &dbTx{conn: c}, nil } -// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns -// an error the transaction is rolled back. The context will be used when executing the transaction control statements -// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. -func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - return c.BeginTxFunc(ctx, TxOptions{}, f) -} - -// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return -// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be -// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect -// the execution of f. -func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { - var tx Tx - tx, err = c.BeginTx(ctx, txOptions) - if err != nil { - return err - } - defer func() { - rollbackErr := tx.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(tx) - if fErr != nil { - _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return tx.Commit(ctx) -} - // Tx represents a database transaction. // // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx @@ -138,20 +105,17 @@ type Tx interface { // Begin starts a pseudo nested transaction. Begin(ctx context.Context) (Tx, error) - // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested - // transaction will be committed. If it does then it will be rolled back. - BeginFunc(ctx context.Context, f func(Tx) error) (err error) - // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested - // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple - // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then - // ErrTxCommitRollback will be returned. + // transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is + // otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already + // in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned. Commit(ctx context.Context) error // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a - // pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to - // call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error - // condition. Any other failure of a real transaction will result in the connection being closed. + // pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already + // closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will + // be called first in a non-error condition. Any other failure of a real transaction will result in the connection + // being closed. Rollback(ctx context.Context) error CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) @@ -194,32 +158,6 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil } -func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if tx.closed { - return ErrTxClosed - } - - var savepoint Tx - savepoint, err = tx.Begin(ctx) - if err != nil { - return err - } - defer func() { - rollbackErr := savepoint.Rollback(ctx) - if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { - err = rollbackErr - } - }() - - fErr := f(savepoint) - if fErr != nil { - _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return - return fErr - } - - return savepoint.Commit(ctx) -} - // Commit commits the transaction. func (tx *dbTx) Commit(ctx context.Context) error { if tx.closed { @@ -335,14 +273,6 @@ func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { return sp.tx.Begin(ctx) } -func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { - if sp.closed { - return ErrTxClosed - } - - return sp.tx.BeginFunc(ctx, f) -} - // Commit releases the savepoint essentially committing the pseudo nested transaction. func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { if sp.closed { @@ -427,3 +357,59 @@ func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { func (sp *dbSimulatedNestedTx) Conn() *Conn { return sp.tx.Conn() } + +// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginFunc( + ctx context.Context, + db interface { + Begin(ctx context.Context) (Tx, error) + }, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.Begin(ctx) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn +// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn. +func BeginTxFunc( + ctx context.Context, + db interface { + BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) + }, + txOptions TxOptions, + fn func(Tx) error, +) (err error) { + var tx Tx + tx, err = db.BeginTx(ctx, txOptions) + if err != nil { + return err + } + + return beginFuncExec(ctx, tx, fn) +} + +func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) { + defer func() { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { + err = rollbackErr + } + }() + + fErr := fn(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} diff --git a/tx_test.go b/tx_test.go index d45553a2..9c1c70d3 100644 --- a/tx_test.go +++ b/tx_test.go @@ -312,7 +312,7 @@ func TestBeginFunc(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return nil @@ -341,7 +341,7 @@ func TestBeginFuncRollbackOnError(t *testing.T) { _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) - err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return errors.New("some error") @@ -522,15 +522,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") require.NoError(t, err) return nil @@ -565,11 +565,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) - err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) return errors.New("do a rollback")