mirror of https://github.com/jackc/pgx.git
parent
26f6ae2c86
commit
6f1c5cc3e6
20
tx.go
20
tx.go
|
@ -39,6 +39,7 @@ const (
|
||||||
TxStatusInProgress = 0
|
TxStatusInProgress = 0
|
||||||
TxStatusCommitFailure = -1
|
TxStatusCommitFailure = -1
|
||||||
TxStatusRollbackFailure = -2
|
TxStatusRollbackFailure = -2
|
||||||
|
TxStatusInFailure = -3
|
||||||
TxStatusCommitSuccess = 1
|
TxStatusCommitSuccess = 1
|
||||||
TxStatusRollbackSuccess = 2
|
TxStatusRollbackSuccess = 2
|
||||||
)
|
)
|
||||||
|
@ -70,6 +71,7 @@ func (txOptions *TxOptions) beginSQL() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrTxClosed = errors.New("tx is closed")
|
var ErrTxClosed = errors.New("tx is closed")
|
||||||
|
var ErrTxInFailure = errors.New("tx failed")
|
||||||
|
|
||||||
// ErrTxCommitRollback occurs when an error has occurred in a transaction and
|
// ErrTxCommitRollback occurs when an error has occurred in a transaction and
|
||||||
// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
|
// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
|
||||||
|
@ -115,7 +117,7 @@ func (tx *Tx) Commit() error {
|
||||||
|
|
||||||
// CommitEx commits the transaction with a context.
|
// CommitEx commits the transaction with a context.
|
||||||
func (tx *Tx) CommitEx(ctx context.Context) error {
|
func (tx *Tx) CommitEx(ctx context.Context) error {
|
||||||
if tx.status != TxStatusInProgress {
|
if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
|
||||||
return ErrTxClosed
|
return ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +152,7 @@ func (tx *Tx) Rollback() error {
|
||||||
|
|
||||||
// RollbackEx is the context version of Rollback
|
// RollbackEx is the context version of Rollback
|
||||||
func (tx *Tx) RollbackEx(ctx context.Context) error {
|
func (tx *Tx) RollbackEx(ctx context.Context) error {
|
||||||
if tx.status != TxStatusInProgress {
|
if !(tx.status == TxStatusInProgress || tx.status == TxStatusInFailure) {
|
||||||
return ErrTxClosed
|
return ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,6 +179,9 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag,
|
||||||
|
|
||||||
// ExecEx delegates to the underlying *Conn
|
// ExecEx delegates to the underlying *Conn
|
||||||
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
if tx.status == TxStatusInFailure {
|
||||||
|
return CommandTag(""), ErrTxInFailure
|
||||||
|
}
|
||||||
if tx.status != TxStatusInProgress {
|
if tx.status != TxStatusInProgress {
|
||||||
return CommandTag(""), ErrTxClosed
|
return CommandTag(""), ErrTxClosed
|
||||||
}
|
}
|
||||||
|
@ -191,6 +196,9 @@ func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) {
|
||||||
|
|
||||||
// PrepareEx delegates to the underlying *Conn
|
// PrepareEx delegates to the underlying *Conn
|
||||||
func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
|
||||||
|
if tx.status == TxStatusInFailure {
|
||||||
|
return nil, ErrTxInFailure
|
||||||
|
}
|
||||||
if tx.status != TxStatusInProgress {
|
if tx.status != TxStatusInProgress {
|
||||||
return nil, ErrTxClosed
|
return nil, ErrTxClosed
|
||||||
}
|
}
|
||||||
|
@ -205,6 +213,11 @@ func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||||
|
|
||||||
// QueryEx delegates to the underlying *Conn
|
// QueryEx delegates to the underlying *Conn
|
||||||
func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
||||||
|
if tx.status == TxStatusInFailure {
|
||||||
|
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||||
|
err := ErrTxInFailure
|
||||||
|
return &Rows{closed: true, err: err}, err
|
||||||
|
}
|
||||||
if tx.status != TxStatusInProgress {
|
if tx.status != TxStatusInProgress {
|
||||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||||
err := ErrTxClosed
|
err := ErrTxClosed
|
||||||
|
@ -238,6 +251,9 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
|
||||||
// Status returns the status of the transaction from the set of
|
// Status returns the status of the transaction from the set of
|
||||||
// pgx.TxStatus* constants.
|
// pgx.TxStatus* constants.
|
||||||
func (tx *Tx) Status() int8 {
|
func (tx *Tx) Status() int8 {
|
||||||
|
if tx.status == TxStatusInProgress && tx.conn.txStatus == 'E' {
|
||||||
|
tx.status = TxStatusInFailure
|
||||||
|
}
|
||||||
return tx.status
|
return tx.status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
33
tx_test.go
33
tx_test.go
|
@ -354,6 +354,39 @@ func TestTxStatus(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxStatusErrorInTransactions(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tx, err := conn.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := tx.Status(); status != pgx.TxStatusInProgress {
|
||||||
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("syntax error")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error but did not get one")
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := tx.Status(); status != pgx.TxStatusInFailure {
|
||||||
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
|
||||||
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTxErr(t *testing.T) {
|
func TestTxErr(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue