diff --git a/tx.go b/tx.go index 81fcfa26..6561f49e 100644 --- a/tx.go +++ b/tx.go @@ -39,6 +39,7 @@ const ( TxStatusInProgress = 0 TxStatusCommitFailure = -1 TxStatusRollbackFailure = -2 + TxStatusInFailure = -3 TxStatusCommitSuccess = 1 TxStatusRollbackSuccess = 2 ) @@ -70,6 +71,7 @@ func (txOptions *TxOptions) beginSQL() string { } var ErrTxClosed = errors.New("tx is closed") +var ErrTxInFailure = errors.New("tx failed") // ErrTxCommitRollback occurs when an error has occurred in a transaction and // Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but @@ -115,7 +117,7 @@ func (tx *Tx) Commit() error { // CommitEx commits the transaction with a context. func (tx *Tx) CommitEx(ctx context.Context) error { - if tx.status != TxStatusInProgress { + if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) { return ErrTxClosed } @@ -150,7 +152,7 @@ func (tx *Tx) Rollback() error { // RollbackEx is the context version of Rollback func (tx *Tx) RollbackEx(ctx context.Context) error { - if tx.status != TxStatusInProgress { + if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) { return ErrTxClosed } @@ -177,6 +179,9 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, // ExecEx delegates to the underlying *Conn func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { + if tx.status == TxStatusInFailure { + return CommandTag(""), ErrTxInFailure + } if tx.status != TxStatusInProgress { return CommandTag(""), ErrTxClosed } @@ -191,6 +196,9 @@ func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { // PrepareEx delegates to the underlying *Conn func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { + if tx.status == TxStatusInFailure { + return nil, ErrTxInFailure + } if tx.status != TxStatusInProgress { return nil, ErrTxClosed } @@ -205,6 +213,11 @@ func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { // QueryEx delegates to the underlying *Conn func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { + if tx.status == TxStatusInFailure { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxInFailure + return &Rows{closed: true, err: err}, err + } if tx.status != TxStatusInProgress { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed @@ -238,6 +251,9 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. func (tx *Tx) Status() int8 { + if tx.status == TxStatusInProgress && tx.conn.txStatus == 'E' { + tx.status = TxStatusInFailure + } return tx.status } diff --git a/tx_test.go b/tx_test.go index b25e1c9f..8c562b7e 100644 --- a/tx_test.go +++ b/tx_test.go @@ -354,6 +354,39 @@ func TestTxStatus(t *testing.T) { } } +func TestTxStatusErrorInTransactions(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + if status := tx.Status(); status != pgx.TxStatusInProgress { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) + } + + _, err = tx.Exec("syntax error") + if err == nil { + t.Fatal("expected an error but did not get one") + } + + if status := tx.Status(); status != pgx.TxStatusInFailure { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if status := tx.Status(); status != pgx.TxStatusRollbackSuccess { + t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status) + } +} + func TestTxErr(t *testing.T) { t.Parallel()