Replace Begin and BeginTx methods with functions

pull/1281/head
Jack Christensen 2022-07-09 17:25:55 -05:00
parent 62f0347586
commit 31ec18cc65
8 changed files with 82 additions and 117 deletions

View File

@ -141,6 +141,11 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
* `CollectOneRow` collects one row using `RowTo*` functions. * `CollectOneRow` collects one row using `RowTo*` functions.
* `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. * `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`.
## Tx Helpers
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
## SendBatch Uses Pipeline Mode When Appropriate ## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1

4
doc.go
View File

@ -247,10 +247,10 @@ These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. Use BeginTx to control the transaction mode.
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use. transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
return err return err
}) })

View File

@ -92,14 +92,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return c.Conn().BeginTx(ctx, txOptions) return c.Conn().BeginTx(ctx, txOptions)
} }
func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return c.Conn().BeginFunc(ctx, f)
}
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
return c.Conn().BeginTxFunc(ctx, txOptions, f)
}
func (c *Conn) Ping(ctx context.Context) error { func (c *Conn) Ping(ctx context.Context) error {
return c.Conn().Ping(ctx) return c.Conn().Ping(ctx)
} }

View File

@ -570,20 +570,6 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return &Tx{t: t, c: c}, err return &Tx{t: t, c: c}, err
} }
func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
}
func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
c, err := p.Acquire(ctx)
if err != nil {
return err
}
defer c.Release()
return c.BeginTxFunc(ctx, txOptions, f)
}
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
c, err := p.Acquire(ctx) c, err := p.Acquire(ctx)
if err != nil { if err != nil {

View File

@ -806,15 +806,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
db.Exec(context.Background(), "drop table pgxpooltx") db.Exec(context.Background(), "drop table pgxpooltx")
}() }()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err) require.NoError(t, err)
return nil return nil
@ -853,11 +853,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
db.Exec(context.Background(), "drop table pgxpooltx") db.Exec(context.Background(), "drop table pgxpooltx")
}() }()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err) require.NoError(t, err)
return errors.New("do a rollback") return errors.New("do a rollback")

View File

@ -18,10 +18,6 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
return tx.t.Begin(ctx) return tx.t.Begin(ctx)
} }
func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return tx.t.BeginFunc(ctx, f)
}
// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed
// if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status
// (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned.

140
tx.go
View File

@ -94,39 +94,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
return &dbTx{conn: c}, nil return &dbTx{conn: c}, nil
} }
// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
// an error the transaction is rolled back. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
return c.BeginTxFunc(ctx, TxOptions{}, f)
}
// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
// the execution of f.
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
var tx Tx
tx, err = c.BeginTx(ctx, txOptions)
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := f(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}
// Tx represents a database transaction. // Tx represents a database transaction.
// //
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
@ -138,20 +105,17 @@ type Tx interface {
// Begin starts a pseudo nested transaction. // Begin starts a pseudo nested transaction.
Begin(ctx context.Context) (Tx, error) Begin(ctx context.Context) (Tx, error)
// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
// transaction will be committed. If it does then it will be rolled back.
BeginFunc(ctx context.Context, f func(Tx) error) (err error)
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple // transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then // otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already
// ErrTxCommitRollback will be returned. // in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned.
Commit(ctx context.Context) error Commit(ctx context.Context) error
// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
// pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to // pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already
// call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error // closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will
// condition. Any other failure of a real transaction will result in the connection being closed. // be called first in a non-error condition. Any other failure of a real transaction will result in the connection
// being closed.
Rollback(ctx context.Context) error Rollback(ctx context.Context) error
CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
@ -194,32 +158,6 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
} }
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if tx.closed {
return ErrTxClosed
}
var savepoint Tx
savepoint, err = tx.Begin(ctx)
if err != nil {
return err
}
defer func() {
rollbackErr := savepoint.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := f(savepoint)
if fErr != nil {
_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return savepoint.Commit(ctx)
}
// Commit commits the transaction. // Commit commits the transaction.
func (tx *dbTx) Commit(ctx context.Context) error { func (tx *dbTx) Commit(ctx context.Context) error {
if tx.closed { if tx.closed {
@ -335,14 +273,6 @@ func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
return sp.tx.Begin(ctx) return sp.tx.Begin(ctx)
} }
func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if sp.closed {
return ErrTxClosed
}
return sp.tx.BeginFunc(ctx, f)
}
// Commit releases the savepoint essentially committing the pseudo nested transaction. // Commit releases the savepoint essentially committing the pseudo nested transaction.
func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
if sp.closed { if sp.closed {
@ -427,3 +357,59 @@ func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
func (sp *dbSimulatedNestedTx) Conn() *Conn { func (sp *dbSimulatedNestedTx) Conn() *Conn {
return sp.tx.Conn() return sp.tx.Conn()
} }
// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginFunc(
ctx context.Context,
db interface {
Begin(ctx context.Context) (Tx, error)
},
fn func(Tx) error,
) (err error) {
var tx Tx
tx, err = db.Begin(ctx)
if err != nil {
return err
}
return beginFuncExec(ctx, tx, fn)
}
// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginTxFunc(
ctx context.Context,
db interface {
BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error)
},
txOptions TxOptions,
fn func(Tx) error,
) (err error) {
var tx Tx
tx, err = db.BeginTx(ctx, txOptions)
if err != nil {
return err
}
return beginFuncExec(ctx, tx, fn)
}
func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) {
defer func() {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := fn(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}

View File

@ -312,7 +312,7 @@ func TestBeginFunc(t *testing.T) {
_, err := conn.Exec(context.Background(), createSql) _, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err) require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
return nil return nil
@ -341,7 +341,7 @@ func TestBeginFuncRollbackOnError(t *testing.T) {
_, err := conn.Exec(context.Background(), createSql) _, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err) require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
return errors.New("some error") return errors.New("some error")
@ -522,15 +522,15 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
_, err := db.Exec(context.Background(), createSql) _, err := db.Exec(context.Background(), createSql)
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)") _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)") _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (3)") _, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
require.NoError(t, err) require.NoError(t, err)
return nil return nil
@ -565,11 +565,11 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
_, err := db.Exec(context.Background(), createSql) _, err := db.Exec(context.Background(), createSql)
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)") _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err) require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)") _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err) require.NoError(t, err)
return errors.New("do a rollback") return errors.New("do a rollback")