mirror of https://github.com/jackc/pgx.git
parent
26f6ae2c86
commit
6f1c5cc3e6
20
tx.go
20
tx.go
|
@ -39,6 +39,7 @@ const (
|
|||
TxStatusInProgress = 0
|
||||
TxStatusCommitFailure = -1
|
||||
TxStatusRollbackFailure = -2
|
||||
TxStatusInFailure = -3
|
||||
TxStatusCommitSuccess = 1
|
||||
TxStatusRollbackSuccess = 2
|
||||
)
|
||||
|
@ -70,6 +71,7 @@ func (txOptions *TxOptions) beginSQL() string {
|
|||
}
|
||||
|
||||
var ErrTxClosed = errors.New("tx is closed")
|
||||
var ErrTxInFailure = errors.New("tx failed")
|
||||
|
||||
// ErrTxCommitRollback occurs when an error has occurred in a transaction and
|
||||
// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
|
||||
|
@ -115,7 +117,7 @@ func (tx *Tx) Commit() error {
|
|||
|
||||
// CommitEx commits the transaction with a context.
|
||||
func (tx *Tx) CommitEx(ctx context.Context) error {
|
||||
if tx.status != TxStatusInProgress {
|
||||
if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
|
||||
return ErrTxClosed
|
||||
}
|
||||
|
||||
|
@ -150,7 +152,7 @@ func (tx *Tx) Rollback() error {
|
|||
|
||||
// RollbackEx is the context version of Rollback
|
||||
func (tx *Tx) RollbackEx(ctx context.Context) error {
|
||||
if tx.status != TxStatusInProgress {
|
||||
if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
|
||||
return ErrTxClosed
|
||||
}
|
||||
|
||||
|
@ -177,6 +179,9 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag,
|
|||
|
||||
// ExecEx delegates to the underlying *Conn
|
||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
if tx.status == TxStatusInFailure {
|
||||
return CommandTag(""), ErrTxInFailure
|
||||
}
|
||||
if tx.status != TxStatusInProgress {
|
||||
return CommandTag(""), ErrTxClosed
|
||||
}
|
||||
|
@ -191,6 +196,9 @@ func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) {
|
|||
|
||||
// PrepareEx delegates to the underlying *Conn
|
||||
func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
||||
if tx.status == TxStatusInFailure {
|
||||
return nil, ErrTxInFailure
|
||||
}
|
||||
if tx.status != TxStatusInProgress {
|
||||
return nil, ErrTxClosed
|
||||
}
|
||||
|
@ -205,6 +213,11 @@ func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||
|
||||
// QueryEx delegates to the underlying *Conn
|
||||
func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
||||
if tx.status == TxStatusInFailure {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
err := ErrTxInFailure
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
if tx.status != TxStatusInProgress {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
err := ErrTxClosed
|
||||
|
@ -238,6 +251,9 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
|
|||
// Status returns the status of the transaction from the set of
|
||||
// pgx.TxStatus* constants.
|
||||
func (tx *Tx) Status() int8 {
|
||||
if tx.status == TxStatusInProgress && tx.conn.txStatus == 'E' {
|
||||
tx.status = TxStatusInFailure
|
||||
}
|
||||
return tx.status
|
||||
}
|
||||
|
||||
|
|
33
tx_test.go
33
tx_test.go
|
@ -354,6 +354,39 @@ func TestTxStatus(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTxStatusErrorInTransactions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if status := tx.Status(); status != pgx.TxStatusInProgress {
|
||||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("syntax error")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error but did not get one")
|
||||
}
|
||||
|
||||
if status := tx.Status(); status != pgx.TxStatusInFailure {
|
||||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
||||
}
|
||||
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
|
||||
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue