From 99e546152212377b211ac6fdd9b2e5a59ebaf616 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 17 Aug 2019 17:22:14 -0500 Subject: [PATCH] Add pgx.Tx interface and pseudo nested transaction support This complicates the idea of a persistent transaction status and error so that concept was removed. --- large_objects.go | 9 +- pgxpool/conn.go | 4 +- pgxpool/pool.go | 4 +- pgxpool/tx.go | 44 ++++++---- stdlib/sql.go | 2 +- tx.go | 224 ++++++++++++++++++++++++++++++++++------------- tx_test.go | 131 +++++++++++++++------------ 7 files changed, 271 insertions(+), 147 deletions(-) diff --git a/large_objects.go b/large_objects.go index 2911c122..f58577fa 100644 --- a/large_objects.go +++ b/large_objects.go @@ -13,12 +13,7 @@ import ( // // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { - tx *Tx -} - -// LargeObjects returns a LargeObjects instance for the transaction. -func (tx *Tx) LargeObjects() LargeObjects { - return LargeObjects{tx: tx} + tx Tx } type LargeObjectMode int32 @@ -84,7 +79,7 @@ func (o *LargeObjects) Unlink(ctx context.Context, oid pgtype.OID) error { // io.Closer type LargeObject struct { ctx context.Context - tx *Tx + tx Tx fd int32 } diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 93d77044..ed663193 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -66,11 +66,11 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNam return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) } -func (c *Conn) Begin(ctx context.Context) (*pgx.Tx, error) { +func (c *Conn) Begin(ctx context.Context) (pgx.Tx, error) { return c.Conn().Begin(ctx) } -func (c *Conn) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (*pgx.Tx, error) { +func (c *Conn) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { return c.Conn().BeginEx(ctx, txOptions) } diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 6ba12568..43020fd2 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -352,10 +352,10 @@ func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return &poolBatchResults{br: br, c: c} } -func (p *Pool) Begin(ctx context.Context) (*Tx, error) { +func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) { return p.BeginEx(ctx, pgx.TxOptions{}) } -func (p *Pool) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (*Tx, error) { +func (p *Pool) BeginEx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { c, err := p.Acquire(ctx) if err != nil { return nil, err diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 7d6d20f7..c9f00290 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -8,10 +8,14 @@ import ( ) type Tx struct { - t *pgx.Tx + t pgx.Tx c *Conn } +func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { + return tx.t.Begin(ctx) +} + func (tx *Tx) Commit(ctx context.Context) error { err := tx.t.Commit(ctx) if tx.c != nil { @@ -30,26 +34,30 @@ func (tx *Tx) Rollback(ctx context.Context) error { return err } -func (tx *Tx) Err() error { - return tx.t.Err() -} - -func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - return tx.c.Exec(ctx, sql, arguments...) -} - -func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { - return tx.c.Query(ctx, sql, args...) -} - -func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { - return tx.c.QueryRow(ctx, sql, args...) +func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return tx.t.CopyFrom(ctx, tableName, columnNames, rowSrc) } func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { - return tx.c.SendBatch(ctx, b) + return tx.t.SendBatch(ctx, b) } -func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { - return tx.c.CopyFrom(ctx, tableName, columnNames, rowSrc) +func (tx *Tx) LargeObjects() pgx.LargeObjects { + return tx.t.LargeObjects() +} + +func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgx.PreparedStatement, error) { + return tx.t.Prepare(ctx, name, sql) +} + +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + return tx.t.Exec(ctx, sql, arguments...) +} + +func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + return tx.t.Query(ctx, sql, args...) +} + +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + return tx.t.QueryRow(ctx, sql, args...) } diff --git a/stdlib/sql.go b/stdlib/sql.go index 5f7c2690..701d6b2e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -484,7 +484,7 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { return args } -type wrapTx struct{ tx *pgx.Tx } +type wrapTx struct{ tx pgx.Tx } func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) } diff --git a/tx.go b/tx.go index 6f4a5eac..cecd8f2c 100644 --- a/tx.go +++ b/tx.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "strconv" "github.com/jackc/pgconn" errors "golang.org/x/xerrors" @@ -35,15 +36,6 @@ const ( NotDeferrable = TxDeferrableMode("not deferrable") ) -const ( - TxStatusInProgress = 0 - TxStatusCommitFailure = -1 - TxStatusRollbackFailure = -2 - TxStatusInFailure = -3 - TxStatusCommitSuccess = 1 - TxStatusRollbackSuccess = 2 -) - type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode @@ -75,13 +67,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") // Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no // auto-rollback on context cancelation. -func (c *Conn) Begin(ctx context.Context) (*Tx, error) { +func (c *Conn) Begin(ctx context.Context) (*dbTx, error) { return c.BeginEx(ctx, TxOptions{}) } // BeginEx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only // affects the begin command. i.e. there is no auto-rollback on context cancelation. -func (c *Conn) BeginEx(ctx context.Context, txOptions TxOptions) (*Tx, error) { +func (c *Conn) BeginEx(ctx context.Context, txOptions TxOptions) (*dbTx, error) { _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or @@ -90,70 +82,100 @@ func (c *Conn) BeginEx(ctx context.Context, txOptions TxOptions) (*Tx, error) { return nil, err } - return &Tx{conn: c}, nil + return &dbTx{conn: c}, nil } -// Tx represents a database transaction. +type Tx interface { + // Begin starts a pseudo nested transaction + Begin(ctx context.Context) (Tx, error) + Commit(ctx context.Context) error + Rollback(ctx context.Context) error + + CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) + SendBatch(ctx context.Context, b *Batch) BatchResults + LargeObjects() LargeObjects + + Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) + + Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) + QueryRow(ctx context.Context, sql string, args ...interface{}) Row +} + +// dbTx represents a database transaction. // -// All Tx methods return ErrTxClosed if Commit or Rollback has already been -// called on the Tx. -type Tx struct { - conn *Conn - err error - status int8 +// All dbTx methods return ErrTxClosed if Commit or Rollback has already been +// called on the dbTx. +type dbTx struct { + conn *Conn + err error + savepointNum int64 + closed bool +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { + if tx.closed { + return nil, ErrTxClosed + } + + tx.savepointNum += 1 + _, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10)) + if err != nil { + return nil, err + } + + return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil } // Commit commits the transaction. -func (tx *Tx) Commit(ctx context.Context) error { - if tx.status != TxStatusInProgress { +func (tx *dbTx) Commit(ctx context.Context) error { + if tx.closed { return ErrTxClosed } commandTag, err := tx.conn.Exec(ctx, "commit") - if err == nil && string(commandTag) == "COMMIT" { - tx.status = TxStatusCommitSuccess - } else if err == nil && string(commandTag) == "ROLLBACK" { - tx.status = TxStatusCommitFailure - tx.err = ErrTxCommitRollback - } else { - tx.status = TxStatusCommitFailure - tx.err = err - // A commit failure leaves the connection in an undefined state - tx.conn.die(errors.New("commit failed")) + tx.closed = true + if err != nil { + // A commit failure leaves the connection in an undefined state so kill the connection (though any error that could + // cause this to fail should have already killed the connection) + tx.conn.die(errors.Errorf("commit failed: %w", err)) + return err + } + if string(commandTag) == "ROLLBACK" { + return ErrTxCommitRollback } - return tx.err + return nil } // Rollback rolls back the transaction. Rollback will return ErrTxClosed if the // Tx is already closed, but is otherwise safe to call multiple times. Hence, a // defer tx.Rollback() is safe even if tx.Commit() will be called first in a // non-error condition. -func (tx *Tx) Rollback(ctx context.Context) error { - if tx.status != TxStatusInProgress { +func (tx *dbTx) Rollback(ctx context.Context) error { + if tx.closed { return ErrTxClosed } - _, tx.err = tx.conn.Exec(ctx, "rollback") - if tx.err == nil { - tx.status = TxStatusRollbackSuccess - } else { - tx.status = TxStatusRollbackFailure + _, err := tx.conn.Exec(ctx, "rollback") + if err != nil { // A rollback failure leaves the connection in an undefined state - tx.conn.die(errors.New("rollback failed")) + tx.conn.die(errors.Errorf("rollback failed: %w", err)) + return err } - return tx.err + return nil } // Exec delegates to the underlying *Conn -func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { +func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { return tx.conn.Exec(ctx, sql, arguments...) } // Prepare delegates to the underlying *Conn -func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) { - if tx.status != TxStatusInProgress { +func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) { + if tx.closed { return nil, ErrTxClosed } @@ -161,8 +183,8 @@ func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*PreparedStatement } // Query delegates to the underlying *Conn -func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { - if tx.status != TxStatusInProgress { +func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed return &connRows{closed: true, err: err}, err @@ -172,14 +194,14 @@ func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, } // QueryRow delegates to the underlying *Conn -func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { +func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { rows, _ := tx.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } // CopyFrom delegates to the underlying *Conn -func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { - if tx.status != TxStatusInProgress { +func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if tx.closed { return 0, ErrTxClosed } @@ -187,24 +209,104 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames [] } // SendBatch delegates to the underlying *Conn -func (tx *Tx) SendBatch(ctx context.Context, b *Batch) BatchResults { - if tx.status != TxStatusInProgress { +func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults { + if tx.closed { return &batchResults{err: ErrTxClosed} } return tx.conn.SendBatch(ctx, b) } -// Status returns the status of the transaction from the set of -// pgx.TxStatus* constants. -func (tx *Tx) Status() int8 { - if tx.status == TxStatusInProgress && tx.conn.pgConn.TxStatus == 'E' { - return TxStatusInFailure - } - return tx.status +// LargeObjects returns a LargeObjects instance for the transaction. +func (tx *dbTx) LargeObjects() LargeObjects { + return LargeObjects{tx: tx} } -// Err returns the final error state, if any, of calling Commit or Rollback. -func (tx *Tx) Err() error { - return tx.err +// dbSavepoint represents a nested transaction implemented by a savepoint. +type dbSavepoint struct { + tx Tx + savepointNum int64 + closed bool +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.Begin(ctx) +} + +// Commit releases the savepoint essentially committing the pseudo nested transaction. +func (sp *dbSavepoint) Commit(ctx context.Context) error { + _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err +} + +// Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return +// ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback() +// is safe even if sp.Commit() will be called first in a non-error condition. +func (sp *dbSavepoint) Rollback(ctx context.Context) error { + _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err +} + +// Exec delegates to the underlying Tx +func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Exec(ctx, sql, arguments...) +} + +// Prepare delegates to the underlying Tx +func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Prepare(ctx, name, sql) +} + +// Query delegates to the underlying Tx +func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + if sp.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &connRows{closed: true, err: err}, err + } + + return sp.tx.Query(ctx, sql, args...) +} + +// QueryRow delegates to the underlying Tx +func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := sp.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} + +// CopyFrom delegates to the underlying *Conn +func (sp *dbSavepoint) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if sp.closed { + return 0, ErrTxClosed + } + + return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// SendBatch delegates to the underlying *Conn +func (sp *dbSavepoint) SendBatch(ctx context.Context, b *Batch) BatchResults { + if sp.closed { + return &batchResults{err: ErrTxClosed} + } + + return sp.tx.SendBatch(ctx, b) +} + +func (sp *dbSavepoint) LargeObjects() LargeObjects { + return LargeObjects{tx: sp} } diff --git a/tx_test.go b/tx_test.go index 93edb048..841c7909 100644 --- a/tx_test.go +++ b/tx_test.go @@ -232,102 +232,121 @@ func TestBeginReadOnly(t *testing.T) { } } -func TestTxStatus(t *testing.T) { +func TestTxNestedTransactionCommit(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } - if status := tx.Status(); status != pgx.TxStatusInProgress { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) } - if err := tx.Rollback(context.Background()); err != nil { + nestedTx, err := tx.Begin(context.Background()) + if err != nil { t.Fatal(err) } - if status := tx.Status(); status != pgx.TxStatusRollbackSuccess { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) + _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") + if err != nil { + t.Fatalf("nestedTx.Exec failed: %v", err) + } + + err = nestedTx.Commit(context.Background()) + if err != nil { + t.Fatalf("nestedTx.Commit failed: %v", err) + } + + err = tx.Commit(context.Background()) + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 2 { + t.Fatalf("Did not receive correct number of rows: %v", n) } } -func TestTxStatusErrorInTransactions(t *testing.T) { +func TestTxNestedTransactionRollback(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } - if status := tx.Status(); status != pgx.TxStatusInProgress { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) } - _, err = tx.Exec(context.Background(), "savepoint s") + nestedTx, err := tx.Begin(context.Background()) if err != nil { t.Fatal(err) } - _, err = tx.Exec(context.Background(), "syntax error") - if err == nil { - t.Fatal("expected an error but did not get one") - } - - if status := tx.Status(); status != pgx.TxStatusInFailure { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status) - } - - _, err = tx.Exec(context.Background(), "rollback to s") + _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") if err != nil { - t.Fatal(err) + t.Fatalf("nestedTx.Exec failed: %v", err) } - if status := tx.Status(); status != pgx.TxStatusInProgress { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) + err = nestedTx.Rollback(context.Background()) + if err != nil { + t.Fatalf("nestedTx.Rollback failed: %v", err) } - if err := tx.Rollback(context.Background()); err != nil { - t.Fatal(err) + _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) } - if status := tx.Status(); status != pgx.TxStatusRollbackSuccess { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) - } -} - -func TestTxErr(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tx, err := conn.Begin(context.Background()) - if err != nil { - t.Fatal(err) - } - - // Purposely break transaction - if _, err := tx.Exec(context.Background(), "syntax error"); err == nil { - t.Fatal("Unexpected success") - } - - if err := tx.Commit(context.Background()); err != pgx.ErrTxCommitRollback { - t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) - } - - if status := tx.Status(); status != pgx.TxStatusCommitFailure { - t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) - } - - if err := tx.Err(); err != pgx.ErrTxCommitRollback { - t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) + err = tx.Commit(context.Background()) + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 2 { + t.Fatalf("Did not receive correct number of rows: %v", n) } }