Add BeginFunc and BeginTxFunc

fixes #821
pull/955/head
Jack Christensen 2021-02-20 18:30:18 -06:00
parent 373bb84e9d
commit ac2918b9a3
7 changed files with 340 additions and 0 deletions

11
doc.go
View File

@ -252,6 +252,17 @@ These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. Use BeginTx to control the transaction mode.
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
return err
})
if err != nil {
return err
}
Prepared Statements Prepared Statements
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx

View File

@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return c.Conn().BeginTx(ctx, txOptions) return c.Conn().BeginTx(ctx, txOptions)
} }
func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return c.Conn().BeginFunc(ctx, f)
}
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
return c.Conn().BeginTxFunc(ctx, txOptions, f)
}
func (c *Conn) Ping(ctx context.Context) error { func (c *Conn) Ping(ctx context.Context) error {
return c.Conn().Ping(ctx) return c.Conn().Ping(ctx)
} }

View File

@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
return &Tx{t: t, c: c}, err return &Tx{t: t, c: c}, err
} }
func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
}
func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
c, err := p.Acquire(ctx)
if err != nil {
return err
}
defer c.Release()
return c.BeginTxFunc(ctx, txOptions, f)
}
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
c, err := p.Acquire(ctx) c, err := p.Acquire(ctx)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package pgxpool_test
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
assert.EqualValues(t, 0, db.Stat().TotalConns()) assert.EqualValues(t, 0, db.Stat().TotalConns())
} }
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
return nil
})
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
}
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
return errors.New("do a rollback")
})
require.EqualError(t, err, "do a rollback")
_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
}

View File

@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
return tx.t.Begin(ctx) return tx.t.Begin(ctx)
} }
func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
return tx.t.BeginFunc(ctx, f)
}
func (tx *Tx) Commit(ctx context.Context) error { func (tx *Tx) Commit(ctx context.Context) error {
err := tx.t.Commit(ctx) err := tx.t.Commit(ctx)
if tx.c != nil { if tx.c != nil {

71
tx.go
View File

@ -85,6 +85,39 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
return &dbTx{conn: c}, nil return &dbTx{conn: c}, nil
} }
// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
// an error the transaction is rolled back. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
return c.BeginTxFunc(ctx, TxOptions{}, f)
}
// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
// the execution of f.
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
var tx Tx
tx, err = c.BeginTx(ctx, TxOptions{})
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback(ctx)
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
err = rollbackErr
}
}()
fErr := f(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}
// Tx represents a database transaction. // Tx represents a database transaction.
// //
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
@ -96,6 +129,10 @@ type Tx interface {
// Begin starts a pseudo nested transaction. // Begin starts a pseudo nested transaction.
Begin(ctx context.Context) (Tx, error) Begin(ctx context.Context) (Tx, error)
// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
// transaction will be committed. If it does then it will be rolled back.
BeginFunc(ctx context.Context, f func(Tx) error) (err error)
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
@ -149,6 +186,32 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
} }
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if tx.closed {
return ErrTxClosed
}
var savepoint Tx
savepoint, err = tx.Begin(ctx)
if err != nil {
return err
}
defer func() {
rollbackErr := savepoint.Rollback(ctx)
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
err = rollbackErr
}
}()
fErr := f(savepoint)
if fErr != nil {
_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return savepoint.Commit(ctx)
}
// Commit commits the transaction. // Commit commits the transaction.
func (tx *dbTx) Commit(ctx context.Context) error { func (tx *dbTx) Commit(ctx context.Context) error {
if tx.closed { if tx.closed {
@ -273,6 +336,14 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
return sp.tx.Begin(ctx) return sp.tx.Begin(ctx)
} }
func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if sp.closed {
return ErrTxClosed
}
return sp.tx.BeginFunc(ctx, f)
}
// Commit releases the savepoint essentially committing the pseudo nested transaction. // Commit releases the savepoint essentially committing the pseudo nested transaction.
func (sp *dbSavepoint) Commit(ctx context.Context) error { func (sp *dbSavepoint) Commit(ctx context.Context) error {
if sp.closed { if sp.closed {

View File

@ -2,6 +2,7 @@ package pgx_test
import ( import (
"context" "context"
"errors"
"os" "os"
"testing" "testing"
@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) {
} }
} }
func TestBeginFunc(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
}
func TestBeginFuncRollbackOnError(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := conn.Exec(context.Background(), createSql)
require.NoError(t, err)
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
return errors.New("some error")
})
require.EqualError(t, err, "some error")
var n int64
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 0, n)
}
func TestBeginReadOnly(t *testing.T) { func TestBeginReadOnly(t *testing.T) {
t.Parallel() t.Parallel()
@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) {
t.Fatalf("Did not receive correct number of rows: %v", n) t.Fatalf("Did not receive correct number of rows: %v", n)
} }
} }
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
t.Parallel()
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, db)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := db.Exec(context.Background(), createSql)
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
require.NoError(t, err)
return nil
})
return nil
})
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
}
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
t.Parallel()
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, db)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
_, err := db.Exec(context.Background(), createSql)
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
require.NoError(t, err)
return errors.New("do a rollback")
})
require.EqualError(t, err, "do a rollback")
_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
require.NoError(t, err)
return nil
})
var n int64
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
}