diff --git a/tx.go b/tx.go index f5e2b02a..944560ef 100644 --- a/tx.go +++ b/tx.go @@ -98,9 +98,8 @@ type Tx interface { // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple - // times. If the commit fails with a rollback status (e.g. a deferred constraint was violated) then - // ErrTxCommitRollback will be returned. Any other failure of a real transaction will result in the connection being - // closed. + // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then + // ErrTxCommitRollback will be returned. Commit(ctx context.Context) error // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a @@ -158,9 +157,6 @@ func (tx *dbTx) Commit(ctx context.Context) error { commandTag, err := tx.conn.Exec(ctx, "commit") tx.closed = true if err != nil { - // A commit failure leaves the connection in an undefined state so kill the connection (though any error that could - // cause this to fail should have already killed the connection) - tx.conn.die(errors.Errorf("commit failed: %w", err)) return err } if string(commandTag) == "ROLLBACK" { diff --git a/tx_test.go b/tx_test.go index 7a640995..e0928c1b 100644 --- a/tx_test.go +++ b/tx_test.go @@ -98,6 +98,51 @@ func TestTxCommitWhenTxBroken(t *testing.T) { } } +func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(context.Background(), createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin(context.Background()) + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Commit(context.Background()) + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" { + t.Fatalf("Expected unique constraint violation 23505, got %#v", err) + } + + var n int64 + err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + func TestTxCommitSerializationFailure(t *testing.T) { t.Parallel() @@ -145,6 +190,9 @@ func TestTxCommitSerializationFailure(t *testing.T) { if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { t.Fatalf("Expected serialization error 40001, got %#v", err) } + + ensureConnValid(t, c1) + ensureConnValid(t, c2) } func TestTransactionSuccessfulRollback(t *testing.T) {