Tx.Status handles in transaction error

refs #421
issue421
Jack Christensen 2018-05-12 10:23:39 -05:00
parent 26f6ae2c86
commit 6f1c5cc3e6
2 changed files with 51 additions and 2 deletions

20
tx.go
View File

@ -39,6 +39,7 @@ const (
TxStatusInProgress = 0 TxStatusInProgress = 0
TxStatusCommitFailure = -1 TxStatusCommitFailure = -1
TxStatusRollbackFailure = -2 TxStatusRollbackFailure = -2
TxStatusInFailure = -3
TxStatusCommitSuccess = 1 TxStatusCommitSuccess = 1
TxStatusRollbackSuccess = 2 TxStatusRollbackSuccess = 2
) )
@ -70,6 +71,7 @@ func (txOptions *TxOptions) beginSQL() string {
} }
var ErrTxClosed = errors.New("tx is closed") var ErrTxClosed = errors.New("tx is closed")
var ErrTxInFailure = errors.New("tx failed")
// ErrTxCommitRollback occurs when an error has occurred in a transaction and // ErrTxCommitRollback occurs when an error has occurred in a transaction and
// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but // Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
@ -115,7 +117,7 @@ func (tx *Tx) Commit() error {
// CommitEx commits the transaction with a context. // CommitEx commits the transaction with a context.
func (tx *Tx) CommitEx(ctx context.Context) error { func (tx *Tx) CommitEx(ctx context.Context) error {
if tx.status != TxStatusInProgress { if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
return ErrTxClosed return ErrTxClosed
} }
@ -150,7 +152,7 @@ func (tx *Tx) Rollback() error {
// RollbackEx is the context version of Rollback // RollbackEx is the context version of Rollback
func (tx *Tx) RollbackEx(ctx context.Context) error { func (tx *Tx) RollbackEx(ctx context.Context) error {
if tx.status != TxStatusInProgress { if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
return ErrTxClosed return ErrTxClosed
} }
@ -177,6 +179,9 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag,
// ExecEx delegates to the underlying *Conn // ExecEx delegates to the underlying *Conn
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
if tx.status == TxStatusInFailure {
return CommandTag(""), ErrTxInFailure
}
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return CommandTag(""), ErrTxClosed return CommandTag(""), ErrTxClosed
} }
@ -191,6 +196,9 @@ func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) {
// PrepareEx delegates to the underlying *Conn // PrepareEx delegates to the underlying *Conn
func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
if tx.status == TxStatusInFailure {
return nil, ErrTxInFailure
}
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return nil, ErrTxClosed return nil, ErrTxClosed
} }
@ -205,6 +213,11 @@ func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) {
// QueryEx delegates to the underlying *Conn // QueryEx delegates to the underlying *Conn
func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
if tx.status == TxStatusInFailure {
// Because checking for errors can be deferred to the *Rows, build one with the error
err := ErrTxInFailure
return &Rows{closed: true, err: err}, err
}
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
// Because checking for errors can be deferred to the *Rows, build one with the error // Because checking for errors can be deferred to the *Rows, build one with the error
err := ErrTxClosed err := ErrTxClosed
@ -238,6 +251,9 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
// Status returns the status of the transaction from the set of // Status returns the status of the transaction from the set of
// pgx.TxStatus* constants. // pgx.TxStatus* constants.
func (tx *Tx) Status() int8 { func (tx *Tx) Status() int8 {
if tx.status == TxStatusInProgress && tx.conn.txStatus == 'E' {
tx.status = TxStatusInFailure
}
return tx.status return tx.status
} }

View File

@ -354,6 +354,39 @@ func TestTxStatus(t *testing.T) {
} }
} }
func TestTxStatusErrorInTransactions(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tx, err := conn.Begin()
if err != nil {
t.Fatal(err)
}
if status := tx.Status(); status != pgx.TxStatusInProgress {
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
}
_, err = tx.Exec("syntax error")
if err == nil {
t.Fatal("expected an error but did not get one")
}
if status := tx.Status(); status != pgx.TxStatusInFailure {
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
}
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}
if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
}
}
func TestTxErr(t *testing.T) { func TestTxErr(t *testing.T) {
t.Parallel() t.Parallel()