Do not kill connection on transaction commit failure

fixes #780
pull/783/head
Jack Christensen 2020-06-27 12:10:33 -05:00
parent 1f68908da6
commit f8a5bc8273
2 changed files with 50 additions and 6 deletions

8
tx.go
View File

@ -98,9 +98,8 @@ type Tx interface {
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
// times. If the commit fails with a rollback status (e.g. a deferred constraint was violated) then // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
// ErrTxCommitRollback will be returned. Any other failure of a real transaction will result in the connection being // ErrTxCommitRollback will be returned.
// closed.
Commit(ctx context.Context) error Commit(ctx context.Context) error
// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
@ -158,9 +157,6 @@ func (tx *dbTx) Commit(ctx context.Context) error {
commandTag, err := tx.conn.Exec(ctx, "commit") commandTag, err := tx.conn.Exec(ctx, "commit")
tx.closed = true tx.closed = true
if err != nil { if err != nil {
// A commit failure leaves the connection in an undefined state so kill the connection (though any error that could
// cause this to fail should have already killed the connection)
tx.conn.die(errors.Errorf("commit failed: %w", err))
return err return err
} }
if string(commandTag) == "ROLLBACK" { if string(commandTag) == "ROLLBACK" {

View File

@ -98,6 +98,51 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
} }
} }
func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
if _, err := conn.Exec(context.Background(), createSql); err != nil {
t.Fatalf("Failed to create table: %v", err)
}
tx, err := conn.Begin(context.Background())
if err != nil {
t.Fatalf("conn.Begin failed: %v", err)
}
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
t.Fatalf("tx.Exec failed: %v", err)
}
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
t.Fatalf("tx.Exec failed: %v", err)
}
err = tx.Commit(context.Background())
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
}
var n int64
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
if err != nil {
t.Fatalf("QueryRow Scan failed: %v", err)
}
if n != 0 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
}
func TestTxCommitSerializationFailure(t *testing.T) { func TestTxCommitSerializationFailure(t *testing.T) {
t.Parallel() t.Parallel()
@ -145,6 +190,9 @@ func TestTxCommitSerializationFailure(t *testing.T) {
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
t.Fatalf("Expected serialization error 40001, got %#v", err) t.Fatalf("Expected serialization error 40001, got %#v", err)
} }
ensureConnValid(t, c1)
ensureConnValid(t, c2)
} }
func TestTransactionSuccessfulRollback(t *testing.T) { func TestTransactionSuccessfulRollback(t *testing.T) {