mirror of https://github.com/jackc/pgx.git
Added *Connection.Transaction
parent
ad27e43224
commit
26105f4409
|
@ -415,6 +415,26 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Transaction(f func() bool) (committed bool, err error) {
|
||||
if _, err = c.Execute("begin"); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if committed && c.txStatus == 'T' {
|
||||
_, err = c.Execute("commit")
|
||||
if err != nil {
|
||||
committed = false
|
||||
}
|
||||
} else {
|
||||
_, err = c.Execute("rollback")
|
||||
committed = false
|
||||
}
|
||||
}()
|
||||
|
||||
committed = f()
|
||||
return
|
||||
}
|
||||
|
||||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
|
|
|
@ -445,7 +445,6 @@ func TestPrepareFailure(t *testing.T) {
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
|
||||
if err = conn.Prepare("badSQL", "select foo"); err == nil {
|
||||
t.Fatal("Prepare should have failed with syntax error")
|
||||
}
|
||||
|
@ -454,3 +453,145 @@ func TestPrepareFailure(t *testing.T) {
|
|||
t.Fatalf("Prepare failure appears to have broken connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
conn, err := Connect(*defaultConnectionParameters)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
createSql := `
|
||||
create temporary table foo(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
if _, err := conn.Execute(createSql); err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
var committed bool
|
||||
|
||||
// Transaction happy path -- it executes function and commits
|
||||
committed, err = conn.Transaction(func() bool {
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction unexpectedly failed: ", err)
|
||||
}
|
||||
if !committed {
|
||||
t.Fatal("Transaction was not committed")
|
||||
}
|
||||
|
||||
var n interface{}
|
||||
n, err = conn.SelectValue("select count(*) from foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||
}
|
||||
if n.(int64) != 1 {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
|
||||
_, err = conn.Execute("truncate foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error truncating foo: %v", err)
|
||||
}
|
||||
|
||||
// It rolls back when passed function returns false
|
||||
committed, err = conn.Transaction(func() bool {
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction unexpectedly failed: ", err)
|
||||
}
|
||||
if committed {
|
||||
t.Fatal("Transaction should not have been committed")
|
||||
}
|
||||
n, err = conn.SelectValue("select count(*) from foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||
}
|
||||
if n.(int64) != 0 {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
|
||||
// it rolls back changes when connection is in error state
|
||||
committed, err = conn.Transaction(func() bool {
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
if _, err := conn.Execute("invalid"); err == nil {
|
||||
t.Fatal("Execute was supposed to error but didn't")
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction unexpectedly failed: %v", err)
|
||||
}
|
||||
if committed {
|
||||
t.Fatal("Transaction was committed when it shouldn't have been")
|
||||
}
|
||||
n, err = conn.SelectValue("select count(*) from foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||
}
|
||||
if n.(int64) != 0 {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
|
||||
// when commit fails
|
||||
committed, err = conn.Transaction(func() bool {
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Transaction should have failed but didn't")
|
||||
}
|
||||
if committed {
|
||||
t.Fatal("Transaction was committed when it should have failed")
|
||||
}
|
||||
|
||||
n, err = conn.SelectValue("select count(*) from foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||
}
|
||||
if n.(int64) != 0 {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
|
||||
// when something in transaction panicks
|
||||
func() {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
|
||||
committed, err = conn.Transaction(func() bool {
|
||||
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||
t.Fatalf("Failed to insert into table: %v", err)
|
||||
}
|
||||
panic("stop!")
|
||||
return true
|
||||
})
|
||||
|
||||
n, err = conn.SelectValue("select count(*) from foo")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||
}
|
||||
if n.(int64) != 0 {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue