Add BeginFunc and BeginTxFunc

fixes #821
pull/955/head
Jack Christensen 2021-02-20 18:30:18 -06:00
parent 373bb84e9d
commit ac2918b9a3
7 changed files with 340 additions and 0 deletions

11
doc.go
View File

@ -252,6 +252,17 @@ These are internally implemented with savepoints.
Use BeginTx to control the transaction mode.
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
return err
})
if err != nil {
return err
}
Prepared Statements
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx

View File

@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return c.Conn().BeginTx(ctx, txOptions)
}
func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return c.Conn().BeginFunc(ctx, f)
}
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
return c.Conn().BeginTxFunc(ctx, txOptions, f)
}
func (c *Conn) Ping(ctx context.Context) error {
return c.Conn().Ping(ctx)
}

View File

@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return &Tx{t: t, c: c}, err
}
func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
}
func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
c, err := p.Acquire(ctx)
if err != nil {
return err
}
defer c.Release()
return c.BeginTxFunc(ctx, txOptions, f)
}
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
c, err := p.Acquire(ctx)
if err != nil {

View File

@ -2,6 +2,7 @@ package pgxpool_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
assert.EqualValues(t, 0, db.Stat().TotalConns())
}
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
return nil
})
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
}
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
return errors.New("do a rollback")
})
require.EqualError(t, err, "do a rollback")
_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
}

View File

@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
return tx.t.Begin(ctx)
}
func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return tx.t.BeginFunc(ctx, f)
}
func (tx *Tx) Commit(ctx context.Context) error {
err := tx.t.Commit(ctx)
if tx.c != nil {

71
tx.go
View File

@ -85,6 +85,39 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
return &dbTx{conn: c}, nil
}
// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
// an error the transaction is rolled back. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
return c.BeginTxFunc(ctx, TxOptions{}, f)
}
// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
// the execution of f.
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
var tx Tx
tx, err = c.BeginTx(ctx, TxOptions{})
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback(ctx)
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
err = rollbackErr
}
}()
fErr := f(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}
// Tx represents a database transaction.
//
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
@ -96,6 +129,10 @@ type Tx interface {
// Begin starts a pseudo nested transaction.
Begin(ctx context.Context) (Tx, error)
// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
// transaction will be committed. If it does then it will be rolled back.
BeginFunc(ctx context.Context, f func(Tx) error) (err error)
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
@ -149,6 +186,32 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
}
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if tx.closed {
return ErrTxClosed
}
var savepoint Tx
savepoint, err = tx.Begin(ctx)
if err != nil {
return err
}
defer func() {
rollbackErr := savepoint.Rollback(ctx)
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
err = rollbackErr
}
}()
fErr := f(savepoint)
if fErr != nil {
_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return savepoint.Commit(ctx)
}
// Commit commits the transaction.
func (tx *dbTx) Commit(ctx context.Context) error {
if tx.closed {
@ -273,6 +336,14 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
return sp.tx.Begin(ctx)
}
func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if sp.closed {
return ErrTxClosed
}
return sp.tx.BeginFunc(ctx, f)
}
// Commit releases the savepoint essentially committing the pseudo nested transaction.
func (sp *dbSavepoint) Commit(ctx context.Context) error {
if sp.closed {

View File

@ -2,6 +2,7 @@ package pgx_test
import (
"context"
"errors"
"os"
"testing"
@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) {
}
}
func TestBeginFunc(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
}
func TestBeginFuncRollbackOnError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
return errors.New("some error")
})
require.EqualError(t, err, "some error")
var n int64
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 0, n)
}
func TestBeginReadOnly(t *testing.T) {
t.Parallel()
@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
}
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
t.Parallel()
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, db)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := db.Exec(context.Background(), createSql)
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
require.NoError(t, err)
return nil
})
return nil
})
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
}
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
t.Parallel()
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, db)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := db.Exec(context.Background(), createSql)
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err)
return errors.New("do a rollback")
})
require.EqualError(t, err, "do a rollback")
_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
require.NoError(t, err)
return nil
})
var n int64
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
}