mirror of https://github.com/jackc/pgx.git
parent
373bb84e9d
commit
ac2918b9a3
11
doc.go
11
doc.go
|
@ -252,6 +252,17 @@ These are internally implemented with savepoints.
|
||||||
|
|
||||||
Use BeginTx to control the transaction mode.
|
Use BeginTx to control the transaction mode.
|
||||||
|
|
||||||
|
BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the
|
||||||
|
transaction depending on the return value of the function. These can be simpler and less error prone to use.
|
||||||
|
|
||||||
|
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
Prepared Statements
|
Prepared Statements
|
||||||
|
|
||||||
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
|
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
|
||||||
|
|
|
@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
|
||||||
return c.Conn().BeginTx(ctx, txOptions)
|
return c.Conn().BeginTx(ctx, txOptions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||||
|
return c.Conn().BeginFunc(ctx, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
|
||||||
|
return c.Conn().BeginTxFunc(ctx, txOptions, f)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) Ping(ctx context.Context) error {
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
return c.Conn().Ping(ctx)
|
return c.Conn().Ping(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
|
||||||
return &Tx{t: t, c: c}, err
|
return &Tx{t: t, c: c}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||||
|
return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
|
||||||
|
c, err := p.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer c.Release()
|
||||||
|
|
||||||
|
return c.BeginTxFunc(ctx, txOptions, f)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
c, err := p.Acquire(ctx)
|
c, err := p.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package pgxpool_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
|
||||||
|
|
||||||
assert.EqualValues(t, 0, db.Stat().TotalConns())
|
assert.EqualValues(t, 0, db.Stat().TotalConns())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
|
||||||
|
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
drop table if exists pgxpooltx;
|
||||||
|
create temporary table pgxpooltx(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err = db.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
db.Exec(context.Background(), "drop table pgxpooltx")
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 3, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
|
||||||
|
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
drop table if exists pgxpooltx;
|
||||||
|
create temporary table pgxpooltx(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err = db.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
db.Exec(context.Background(), "drop table pgxpooltx")
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return errors.New("do a rollback")
|
||||||
|
})
|
||||||
|
require.EqualError(t, err, "do a rollback")
|
||||||
|
|
||||||
|
_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, n)
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
|
||||||
return tx.t.Begin(ctx)
|
return tx.t.Begin(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||||
|
return tx.t.BeginFunc(ctx, f)
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *Tx) Commit(ctx context.Context) error {
|
func (tx *Tx) Commit(ctx context.Context) error {
|
||||||
err := tx.t.Commit(ctx)
|
err := tx.t.Commit(ctx)
|
||||||
if tx.c != nil {
|
if tx.c != nil {
|
||||||
|
|
71
tx.go
71
tx.go
|
@ -85,6 +85,39 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
|
||||||
return &dbTx{conn: c}, nil
|
return &dbTx{conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
|
||||||
|
// an error the transaction is rolled back. The context will be used when executing the transaction control statements
|
||||||
|
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
|
||||||
|
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
|
||||||
|
return c.BeginTxFunc(ctx, TxOptions{}, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
|
||||||
|
// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
|
||||||
|
// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
|
||||||
|
// the execution of f.
|
||||||
|
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
|
||||||
|
var tx Tx
|
||||||
|
tx, err = c.BeginTx(ctx, TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
rollbackErr := tx.Rollback(ctx)
|
||||||
|
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
|
||||||
|
err = rollbackErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fErr := f(tx)
|
||||||
|
if fErr != nil {
|
||||||
|
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
|
||||||
|
return fErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// Tx represents a database transaction.
|
// Tx represents a database transaction.
|
||||||
//
|
//
|
||||||
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
|
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
|
||||||
|
@ -96,6 +129,10 @@ type Tx interface {
|
||||||
// Begin starts a pseudo nested transaction.
|
// Begin starts a pseudo nested transaction.
|
||||||
Begin(ctx context.Context) (Tx, error)
|
Begin(ctx context.Context) (Tx, error)
|
||||||
|
|
||||||
|
// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
|
||||||
|
// transaction will be committed. If it does then it will be rolled back.
|
||||||
|
BeginFunc(ctx context.Context, f func(Tx) error) (err error)
|
||||||
|
|
||||||
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
|
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
|
||||||
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
|
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
|
||||||
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
|
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
|
||||||
|
@ -149,6 +186,32 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
|
||||||
return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
|
return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
|
||||||
|
if tx.closed {
|
||||||
|
return ErrTxClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
var savepoint Tx
|
||||||
|
savepoint, err = tx.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
rollbackErr := savepoint.Rollback(ctx)
|
||||||
|
if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) {
|
||||||
|
err = rollbackErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fErr := f(savepoint)
|
||||||
|
if fErr != nil {
|
||||||
|
_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
|
||||||
|
return fErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return savepoint.Commit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// Commit commits the transaction.
|
// Commit commits the transaction.
|
||||||
func (tx *dbTx) Commit(ctx context.Context) error {
|
func (tx *dbTx) Commit(ctx context.Context) error {
|
||||||
if tx.closed {
|
if tx.closed {
|
||||||
|
@ -273,6 +336,14 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
|
||||||
return sp.tx.Begin(ctx)
|
return sp.tx.Begin(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
|
||||||
|
if sp.closed {
|
||||||
|
return ErrTxClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return sp.tx.BeginFunc(ctx, f)
|
||||||
|
}
|
||||||
|
|
||||||
// Commit releases the savepoint essentially committing the pseudo nested transaction.
|
// Commit releases the savepoint essentially committing the pseudo nested transaction.
|
||||||
func (sp *dbSavepoint) Commit(ctx context.Context) error {
|
func (sp *dbSavepoint) Commit(ctx context.Context) error {
|
||||||
if sp.closed {
|
if sp.closed {
|
||||||
|
|
141
tx_test.go
141
tx_test.go
|
@ -2,6 +2,7 @@ package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBeginFunc(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := conn.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeginFuncRollbackOnError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := conn.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return errors.New("some error")
|
||||||
|
})
|
||||||
|
require.EqualError(t, err, "some error")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 0, n)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBeginReadOnly(t *testing.T) {
|
func TestBeginReadOnly(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) {
|
||||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, db)
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := db.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 3, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, db)
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := db.Exec(context.Background(), createSql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||||
|
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return errors.New("do a rollback")
|
||||||
|
})
|
||||||
|
require.EqualError(t, err, "do a rollback")
|
||||||
|
|
||||||
|
_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, n)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue