Added *Connection.Transaction

pgx-vs-pq
Jack Christensen 2013-07-13 10:19:29 -04:00
parent ad27e43224
commit 26105f4409
2 changed files with 162 additions and 1 deletions

View File

@ -415,6 +415,26 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
}
}
func (c *Connection) Transaction(f func() bool) (committed bool, err error) {
if _, err = c.Execute("begin"); err != nil {
return
}
defer func() {
if committed && c.txStatus == 'T' {
_, err = c.Execute("commit")
if err != nil {
committed = false
}
} else {
_, err = c.Execute("rollback")
committed = false
}
}()
committed = f()
return
}
// Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages
// is the same regardless of when they occur.

View File

@ -445,7 +445,6 @@ func TestPrepareFailure(t *testing.T) {
}
defer conn.Close()
if err = conn.Prepare("badSQL", "select foo"); err == nil {
t.Fatal("Prepare should have failed with syntax error")
}
@ -454,3 +453,145 @@ func TestPrepareFailure(t *testing.T) {
t.Fatalf("Prepare failure appears to have broken connection: %v", err)
}
}
func TestTransaction(t *testing.T) {
conn, err := Connect(*defaultConnectionParameters)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
defer conn.Close()
createSql := `
create temporary table foo(
id integer,
unique (id) initially deferred
);
`
if _, err := conn.Execute(createSql); err != nil {
t.Fatalf("Failed to create table: %v", err)
}
var committed bool
// Transaction happy path -- it executes function and commits
committed, err = conn.Transaction(func() bool {
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
return true
})
if err != nil {
t.Fatalf("Transaction unexpectedly failed: ", err)
}
if !committed {
t.Fatal("Transaction was not committed")
}
var n interface{}
n, err = conn.SelectValue("select count(*) from foo")
if err != nil {
t.Fatalf("Unexpected error selecting value from foo: %v", err)
}
if n.(int64) != 1 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
_, err = conn.Execute("truncate foo")
if err != nil {
t.Fatalf("Unexpected error truncating foo: %v", err)
}
// It rolls back when passed function returns false
committed, err = conn.Transaction(func() bool {
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
return false
})
if err != nil {
t.Fatalf("Transaction unexpectedly failed: ", err)
}
if committed {
t.Fatal("Transaction should not have been committed")
}
n, err = conn.SelectValue("select count(*) from foo")
if err != nil {
t.Fatalf("Unexpected error selecting value from foo: %v", err)
}
if n.(int64) != 0 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
// it rolls back changes when connection is in error state
committed, err = conn.Transaction(func() bool {
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
if _, err := conn.Execute("invalid"); err == nil {
t.Fatal("Execute was supposed to error but didn't")
}
return true
})
if err != nil {
t.Fatalf("Transaction unexpectedly failed: %v", err)
}
if committed {
t.Fatal("Transaction was committed when it shouldn't have been")
}
n, err = conn.SelectValue("select count(*) from foo")
if err != nil {
t.Fatalf("Unexpected error selecting value from foo: %v", err)
}
if n.(int64) != 0 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
// when commit fails
committed, err = conn.Transaction(func() bool {
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
return true
})
if err == nil {
t.Fatal("Transaction should have failed but didn't")
}
if committed {
t.Fatal("Transaction was committed when it should have failed")
}
n, err = conn.SelectValue("select count(*) from foo")
if err != nil {
t.Fatalf("Unexpected error selecting value from foo: %v", err)
}
if n.(int64) != 0 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
// when something in transaction panicks
func() {
defer func() {
recover()
}()
committed, err = conn.Transaction(func() bool {
if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil {
t.Fatalf("Failed to insert into table: %v", err)
}
panic("stop!")
return true
})
n, err = conn.SelectValue("select count(*) from foo")
if err != nil {
t.Fatalf("Unexpected error selecting value from foo: %v", err)
}
if n.(int64) != 0 {
t.Fatalf("Did not receive correct number of rows: %v", n)
}
}()
}