mirror of https://github.com/jackc/pgx.git
Added *Connection.Transaction
parent
ad27e43224
commit
26105f4409
|
@ -415,6 +415,26 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Transaction(f func() bool) (committed bool, err error) {
|
||||||
|
if _, err = c.Execute("begin"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if committed && c.txStatus == 'T' {
|
||||||
|
_, err = c.Execute("commit")
|
||||||
|
if err != nil {
|
||||||
|
committed = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = c.Execute("rollback")
|
||||||
|
committed = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
committed = f()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Processes messages that are not exclusive to one context such as
|
// Processes messages that are not exclusive to one context such as
|
||||||
// authentication or query response. The response to these messages
|
// authentication or query response. The response to these messages
|
||||||
// is the same regardless of when they occur.
|
// is the same regardless of when they occur.
|
||||||
|
|
|
@ -445,7 +445,6 @@ func TestPrepareFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
|
||||||
if err = conn.Prepare("badSQL", "select foo"); err == nil {
|
if err = conn.Prepare("badSQL", "select foo"); err == nil {
|
||||||
t.Fatal("Prepare should have failed with syntax error")
|
t.Fatal("Prepare should have failed with syntax error")
|
||||||
}
|
}
|
||||||
|
@ -454,3 +453,145 @@ func TestPrepareFailure(t *testing.T) {
|
||||||
t.Fatalf("Prepare failure appears to have broken connection: %v", err)
|
t.Fatalf("Prepare failure appears to have broken connection: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
conn, err := Connect(*defaultConnectionParameters)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
createSql := `
|
||||||
|
create temporary table foo(
|
||||||
|
id integer,
|
||||||
|
unique (id) initially deferred
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
if _, err := conn.Execute(createSql); err != nil {
|
||||||
|
t.Fatalf("Failed to create table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var committed bool
|
||||||
|
|
||||||
|
// Transaction happy path -- it executes function and commits
|
||||||
|
committed, err = conn.Transaction(func() bool {
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transaction unexpectedly failed: ", err)
|
||||||
|
}
|
||||||
|
if !committed {
|
||||||
|
t.Fatal("Transaction was not committed")
|
||||||
|
}
|
||||||
|
|
||||||
|
var n interface{}
|
||||||
|
n, err = conn.SelectValue("select count(*) from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||||
|
}
|
||||||
|
if n.(int64) != 1 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Execute("truncate foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error truncating foo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It rolls back when passed function returns false
|
||||||
|
committed, err = conn.Transaction(func() bool {
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transaction unexpectedly failed: ", err)
|
||||||
|
}
|
||||||
|
if committed {
|
||||||
|
t.Fatal("Transaction should not have been committed")
|
||||||
|
}
|
||||||
|
n, err = conn.SelectValue("select count(*) from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||||
|
}
|
||||||
|
if n.(int64) != 0 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// it rolls back changes when connection is in error state
|
||||||
|
committed, err = conn.Transaction(func() bool {
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.Execute("invalid"); err == nil {
|
||||||
|
t.Fatal("Execute was supposed to error but didn't")
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transaction unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if committed {
|
||||||
|
t.Fatal("Transaction was committed when it shouldn't have been")
|
||||||
|
}
|
||||||
|
n, err = conn.SelectValue("select count(*) from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||||
|
}
|
||||||
|
if n.(int64) != 0 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when commit fails
|
||||||
|
committed, err = conn.Transaction(func() bool {
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Transaction should have failed but didn't")
|
||||||
|
}
|
||||||
|
if committed {
|
||||||
|
t.Fatal("Transaction was committed when it should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = conn.SelectValue("select count(*) from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||||
|
}
|
||||||
|
if n.(int64) != 0 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when something in transaction panicks
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
recover()
|
||||||
|
}()
|
||||||
|
|
||||||
|
committed, err = conn.Transaction(func() bool {
|
||||||
|
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
|
||||||
|
t.Fatalf("Failed to insert into table: %v", err)
|
||||||
|
}
|
||||||
|
panic("stop!")
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
n, err = conn.SelectValue("select count(*) from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error selecting value from foo: %v", err)
|
||||||
|
}
|
||||||
|
if n.(int64) != 0 {
|
||||||
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue