From ac2918b9a32454f04833395a7be31d1e136782f8 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 20 Feb 2021 18:30:18 -0600
Subject: [PATCH] Add BeginFunc and BeginTxFunc

fixes #821
---
 doc.go               |  11 ++++
 pgxpool/conn.go      |   8 +++
 pgxpool/pool.go      |  14 +++++
 pgxpool/pool_test.go |  91 ++++++++++++++++++++++++++++
 pgxpool/tx.go        |   4 ++
 tx.go                |  71 ++++++++++++++++++++++
 tx_test.go           | 141 +++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 340 insertions(+)

diff --git a/doc.go b/doc.go
index b27708d6..51b0d9f4 100644
--- a/doc.go
+++ b/doc.go
@@ -252,6 +252,17 @@ These are internally implemented with savepoints.
 
 Use BeginTx to control the transaction mode.
 
+BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
+transaction depending on the return value of the function. These can be simpler and less error prone to use.
+
+    err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
+        _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
+        return err
+    })
+    if err != nil {
+        return err
+    }
+
 Prepared Statements
 
 Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
diff --git a/pgxpool/conn.go b/pgxpool/conn.go
index 4bd4bb9f..29ca04d0 100644
--- a/pgxpool/conn.go
+++ b/pgxpool/conn.go
@@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
 	return c.Conn().BeginTx(ctx, txOptions)
 }
 
+func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
+	return c.Conn().BeginFunc(ctx, f)
+}
+
+func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
+	return c.Conn().BeginTxFunc(ctx, txOptions, f)
+}
+
 func (c *Conn) Ping(ctx context.Context) error {
 	return c.Conn().Ping(ctx)
 }
diff --git a/pgxpool/pool.go b/pgxpool/pool.go
index 8efb9265..09752aaa 100644
--- a/pgxpool/pool.go
+++ b/pgxpool/pool.go
@@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
 	return &Tx{t: t, c: c}, err
 }
 
+func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
+	return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
+}
+
+func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
+	c, err := p.Acquire(ctx)
+	if err != nil {
+		return err
+	}
+	defer c.Release()
+
+	return c.BeginTxFunc(ctx, txOptions, f)
+}
+
 func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
 	c, err := p.Acquire(ctx)
 	if err != nil {
diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go
index 12f92c0a..85f59256 100644
--- a/pgxpool/pool_test.go
+++ b/pgxpool/pool_test.go
@@ -2,6 +2,7 @@ package pgxpool_test
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"testing"
@@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
 
 	assert.EqualValues(t, 0, db.Stat().TotalConns())
 }
+
+func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
+	db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer db.Close()
+
+	createSql := `
+		drop table if exists pgxpooltx;
+    create temporary table pgxpooltx(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err = db.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	defer func() {
+		db.Exec(context.Background(), "drop table pgxpooltx")
+	}()
+
+	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+		_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
+		require.NoError(t, err)
+
+		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+			_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
+			require.NoError(t, err)
+
+			err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+				_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
+				require.NoError(t, err)
+				return nil
+			})
+
+			return nil
+		})
+		require.NoError(t, err)
+		return nil
+	})
+	require.NoError(t, err)
+
+	var n int64
+	err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 3, n)
+}
+
+func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
+	db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer db.Close()
+
+	createSql := `
+		drop table if exists pgxpooltx;
+    create temporary table pgxpooltx(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err = db.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	defer func() {
+		db.Exec(context.Background(), "drop table pgxpooltx")
+	}()
+
+	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+		_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
+		require.NoError(t, err)
+
+		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+			_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
+			require.NoError(t, err)
+			return errors.New("do a rollback")
+		})
+		require.EqualError(t, err, "do a rollback")
+
+		_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
+		require.NoError(t, err)
+
+		return nil
+	})
+
+	var n int64
+	err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 2, n)
+}
diff --git a/pgxpool/tx.go b/pgxpool/tx.go
index 15e0ee2d..e1c980e1 100644
--- a/pgxpool/tx.go
+++ b/pgxpool/tx.go
@@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
 	return tx.t.Begin(ctx)
 }
 
+func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
+	return tx.t.BeginFunc(ctx, f)
+}
+
 func (tx *Tx) Commit(ctx context.Context) error {
 	err := tx.t.Commit(ctx)
 	if tx.c != nil {
diff --git a/tx.go b/tx.go
index 43f8aa3e..5ba9836a 100644
--- a/tx.go
+++ b/tx.go
@@ -85,6 +85,39 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
 	return &dbTx{conn: c}, nil
 }
 
+// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
+// an error the transaction is rolled back. The context will be used when executing the transaction control statements
+// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
+func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
+	return c.BeginTxFunc(ctx, TxOptions{}, f)
+}
+
+// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
+// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
+// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
+// the execution of f.
+func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
+	var tx Tx
+	tx, err = c.BeginTx(ctx, TxOptions{})
+	if err != nil {
+		return err
+	}
+	defer func() {
+		rollbackErr := tx.Rollback(ctx)
+		if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
+			err = rollbackErr
+		}
+	}()
+
+	fErr := f(tx)
+	if fErr != nil {
+		_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
+		return fErr
+	}
+
+	return tx.Commit(ctx)
+}
+
 // Tx represents a database transaction.
 //
 // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
@@ -96,6 +129,10 @@ type Tx interface {
 	// Begin starts a pseudo nested transaction.
 	Begin(ctx context.Context) (Tx, error)
 
+	// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
+	// transaction will be committed. If it does then it will be rolled back.
+	BeginFunc(ctx context.Context, f func(Tx) error) (err error)
+
 	// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
 	// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
 	// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
@@ -149,6 +186,32 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
 	return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
 }
 
+func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
+	if tx.closed {
+		return ErrTxClosed
+	}
+
+	var savepoint Tx
+	savepoint, err = tx.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		rollbackErr := savepoint.Rollback(ctx)
+		if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
+			err = rollbackErr
+		}
+	}()
+
+	fErr := f(savepoint)
+	if fErr != nil {
+		_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
+		return fErr
+	}
+
+	return savepoint.Commit(ctx)
+}
+
 // Commit commits the transaction.
 func (tx *dbTx) Commit(ctx context.Context) error {
 	if tx.closed {
@@ -273,6 +336,14 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
 	return sp.tx.Begin(ctx)
 }
 
+func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
+	if sp.closed {
+		return ErrTxClosed
+	}
+
+	return sp.tx.BeginFunc(ctx, f)
+}
+
 // Commit releases the savepoint essentially committing the pseudo nested transaction.
 func (sp *dbSavepoint) Commit(ctx context.Context) error {
 	if sp.closed {
diff --git a/tx_test.go b/tx_test.go
index e0928c1b..901052c2 100644
--- a/tx_test.go
+++ b/tx_test.go
@@ -2,6 +2,7 @@ package pgx_test
 
 import (
 	"context"
+	"errors"
 	"os"
 	"testing"
 
@@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) {
 	}
 }
 
+func TestBeginFunc(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err := conn.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
+		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
+		require.NoError(t, err)
+		return nil
+	})
+	require.NoError(t, err)
+
+	var n int64
+	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 1, n)
+}
+
+func TestBeginFuncRollbackOnError(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err := conn.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
+		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
+		require.NoError(t, err)
+		return errors.New("some error")
+	})
+	require.EqualError(t, err, "some error")
+
+	var n int64
+	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 0, n)
+}
+
 func TestBeginReadOnly(t *testing.T) {
 	t.Parallel()
 
@@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) {
 		t.Fatalf("Did not receive correct number of rows: %v", n)
 	}
 }
+
+func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
+	t.Parallel()
+
+	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, db)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err := db.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
+		require.NoError(t, err)
+
+		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
+			require.NoError(t, err)
+
+			err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+				_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
+				require.NoError(t, err)
+				return nil
+			})
+
+			return nil
+		})
+		require.NoError(t, err)
+		return nil
+	})
+	require.NoError(t, err)
+
+	var n int64
+	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 3, n)
+}
+
+func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
+	t.Parallel()
+
+	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, db)
+
+	createSql := `
+    create temporary table foo(
+      id integer,
+      unique (id) initially deferred
+    );
+  `
+
+	_, err := db.Exec(context.Background(), createSql)
+	require.NoError(t, err)
+
+	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
+		require.NoError(t, err)
+
+		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
+			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
+			require.NoError(t, err)
+			return errors.New("do a rollback")
+		})
+		require.EqualError(t, err, "do a rollback")
+
+		_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
+		require.NoError(t, err)
+
+		return nil
+	})
+
+	var n int64
+	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
+	require.NoError(t, err)
+	require.EqualValues(t, 2, n)
+}