diff --git a/doc.go b/doc.go index b27708d6..51b0d9f4 100644 --- a/doc.go +++ b/doc.go @@ -252,6 +252,17 @@ These are internally implemented with savepoints. Use BeginTx to control the transaction mode. +BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the +transaction depending on the return value of the function. These can be simpler and less error prone to use. + + err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + return err + }) + if err != nil { + return err + } + Prepared Statements Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 4bd4bb9f..29ca04d0 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return c.Conn().BeginTx(ctx, txOptions) } +func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { + return c.Conn().BeginFunc(ctx, f) +} + +func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { + return c.Conn().BeginTxFunc(ctx, txOptions, f) +} + func (c *Conn) Ping(ctx context.Context) error { return c.Conn().Ping(ctx) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 8efb9265..09752aaa 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er return &Tx{t: t, c: c}, err } +func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { + return p.BeginTxFunc(ctx, pgx.TxOptions{}, f) +} + +func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { + c, err := p.Acquire(ctx) + if err != nil { + return err + } + defer c.Release() + + return c.BeginTxFunc(ctx, txOptions, f) +} + func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { c, err := p.Acquire(ctx) if err != nil { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 12f92c0a..85f59256 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -2,6 +2,7 @@ package pgxpool_test import ( "context" + "errors" "fmt" "os" "testing" @@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) { assert.EqualValues(t, 0, db.Stat().TotalConns()) } + +func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { + db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + createSql := ` + drop table if exists pgxpooltx; + create temporary table pgxpooltx( + id integer, + unique (id) initially deferred + ); + ` + + _, err = db.Exec(context.Background(), createSql) + require.NoError(t, err) + + defer func() { + db.Exec(context.Background(), "drop table pgxpooltx") + }() + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") + require.NoError(t, err) + return nil + }) + + return nil + }) + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) +} + +func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { + db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer db.Close() + + createSql := ` + drop table if exists pgxpooltx; + create temporary table pgxpooltx( + id integer, + unique (id) initially deferred + ); + ` + + _, err = db.Exec(context.Background(), createSql) + require.NoError(t, err) + + defer func() { + db.Exec(context.Background(), "drop table pgxpooltx") + }() + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") + require.NoError(t, err) + return errors.New("do a rollback") + }) + require.EqualError(t, err, "do a rollback") + + _, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") + require.NoError(t, err) + + return nil + }) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) +} diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 15e0ee2d..e1c980e1 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } +func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { + return tx.t.BeginFunc(ctx, f) +} + func (tx *Tx) Commit(ctx context.Context) error { err := tx.t.Commit(ctx) if tx.c != nil { diff --git a/tx.go b/tx.go index 43f8aa3e..5ba9836a 100644 --- a/tx.go +++ b/tx.go @@ -85,6 +85,39 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { return &dbTx{conn: c}, nil } +// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns +// an error the transaction is rolled back. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. +func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + return c.BeginTxFunc(ctx, TxOptions{}, f) +} + +// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return +// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be +// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect +// the execution of f. +func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { + var tx Tx + tx, err = c.BeginTx(ctx, TxOptions{}) + if err != nil { + return err + } + defer func() { + rollbackErr := tx.Rollback(ctx) + if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) { + err = rollbackErr + } + }() + + fErr := f(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} + // Tx represents a database transaction. // // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx @@ -96,6 +129,10 @@ type Tx interface { // Begin starts a pseudo nested transaction. Begin(ctx context.Context) (Tx, error) + // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested + // transaction will be committed. If it does then it will be rolled back. + BeginFunc(ctx context.Context, f func(Tx) error) (err error) + // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then @@ -149,6 +186,32 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil } +func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + if tx.closed { + return ErrTxClosed + } + + var savepoint Tx + savepoint, err = tx.Begin(ctx) + if err != nil { + return err + } + defer func() { + rollbackErr := savepoint.Rollback(ctx) + if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) { + err = rollbackErr + } + }() + + fErr := f(savepoint) + if fErr != nil { + _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return savepoint.Commit(ctx) +} + // Commit commits the transaction. func (tx *dbTx) Commit(ctx context.Context) error { if tx.closed { @@ -273,6 +336,14 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) { return sp.tx.Begin(ctx) } +func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + if sp.closed { + return ErrTxClosed + } + + return sp.tx.BeginFunc(ctx, f) +} + // Commit releases the savepoint essentially committing the pseudo nested transaction. func (sp *dbSavepoint) Commit(ctx context.Context) error { if sp.closed { diff --git a/tx_test.go b/tx_test.go index e0928c1b..901052c2 100644 --- a/tx_test.go +++ b/tx_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "context" + "errors" "os" "testing" @@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) { } } +func TestBeginFunc(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + _, err := conn.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) +} + +func TestBeginFuncRollbackOnError(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + _, err := conn.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + return errors.New("some error") + }) + require.EqualError(t, err, "some error") + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 0, n) +} + func TestBeginReadOnly(t *testing.T) { t.Parallel() @@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) { t.Fatalf("Did not receive correct number of rows: %v", n) } } + +func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { + t.Parallel() + + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + _, err := db.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") + require.NoError(t, err) + return nil + }) + + return nil + }) + require.NoError(t, err) + return nil + }) + require.NoError(t, err) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) +} + +func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { + t.Parallel() + + db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, db) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + _, err := db.Exec(context.Background(), createSql) + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") + require.NoError(t, err) + + err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { + _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") + require.NoError(t, err) + return errors.New("do a rollback") + }) + require.EqualError(t, err, "do a rollback") + + _, err = db.Exec(context.Background(), "insert into foo(id) values (3)") + require.NoError(t, err) + + return nil + }) + + var n int64 + err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) +}