mirror of https://github.com/jackc/pgx.git
parent
1f68908da6
commit
f8a5bc8273
8
tx.go
8
tx.go
|
@ -98,9 +98,8 @@ type Tx interface {
|
||||||
|
|
||||||
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
|
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
|
||||||
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
|
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
|
||||||
// times. If the commit fails with a rollback status (e.g. a deferred constraint was violated) then
|
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
|
||||||
// ErrTxCommitRollback will be returned. Any other failure of a real transaction will result in the connection being
|
// ErrTxCommitRollback will be returned.
|
||||||
// closed.
|
|
||||||
Commit(ctx context.Context) error
|
Commit(ctx context.Context) error
|
||||||
|
|
||||||
// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
|
// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
|
||||||
|
@ -158,9 +157,6 @@ func (tx *dbTx) Commit(ctx context.Context) error {
|
||||||
commandTag, err := tx.conn.Exec(ctx, "commit")
|
commandTag, err := tx.conn.Exec(ctx, "commit")
|
||||||
tx.closed = true
|
tx.closed = true
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// A commit failure leaves the connection in an undefined state so kill the connection (though any error that could
|
|
||||||
// cause this to fail should have already killed the connection)
|
|
||||||
tx.conn.die(errors.Errorf("commit failed: %w", err))
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if string(commandTag) == "ROLLBACK" {
|
if string(commandTag) == "ROLLBACK" {
|
||||||
|
|
48
tx_test.go
48
tx_test.go
|
@ -98,6 +98,51 @@ func TestTxCommitWhenTxBroken(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
if _, err := conn.Exec(context.Background(), createSql); err != nil {
|
||||||
|
t.Fatalf("Failed to create table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := conn.Begin(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("tx.Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("tx.Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit(context.Background())
|
||||||
|
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
|
||||||
|
t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("QueryRow Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTxCommitSerializationFailure(t *testing.T) {
|
func TestTxCommitSerializationFailure(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -145,6 +190,9 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
|
if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
|
||||||
t.Fatalf("Expected serialization error 40001, got %#v", err)
|
t.Fatalf("Expected serialization error 40001, got %#v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, c1)
|
||||||
|
ensureConnValid(t, c2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransactionSuccessfulRollback(t *testing.T) {
|
func TestTransactionSuccessfulRollback(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue