From 26105f4409ddc09e739dfb153df9fc21f27112d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2013 10:19:29 -0400 Subject: [PATCH] Added *Connection.Transaction --- connection.go | 20 +++++++ connection_test.go | 143 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 1 deletion(-) diff --git a/connection.go b/connection.go index addb4e7b..43801b31 100644 --- a/connection.go +++ b/connection.go @@ -415,6 +415,26 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s } } +func (c *Connection) Transaction(f func() bool) (committed bool, err error) { + if _, err = c.Execute("begin"); err != nil { + return + } + defer func() { + if committed && c.txStatus == 'T' { + _, err = c.Execute("commit") + if err != nil { + committed = false + } + } else { + _, err = c.Execute("rollback") + committed = false + } + }() + + committed = f() + return +} + // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. diff --git a/connection_test.go b/connection_test.go index 002f9e31..a79d3e2c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -445,7 +445,6 @@ func TestPrepareFailure(t *testing.T) { } defer conn.Close() - if err = conn.Prepare("badSQL", "select foo"); err == nil { t.Fatal("Prepare should have failed with syntax error") } @@ -454,3 +453,145 @@ func TestPrepareFailure(t *testing.T) { t.Fatalf("Prepare failure appears to have broken connection: %v", err) } } + +func TestTransaction(t *testing.T) { + conn, err := Connect(*defaultConnectionParameters) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer conn.Close() + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Execute(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + var committed bool + + // Transaction happy path -- it executes function and commits + committed, err = conn.Transaction(func() bool { + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + return true + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: ", err) + } + if !committed { + t.Fatal("Transaction was not committed") + } + + var n interface{} + n, err = conn.SelectValue("select count(*) from foo") + if err != nil { + t.Fatalf("Unexpected error selecting value from foo: %v", err) + } + if n.(int64) != 1 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + _, err = conn.Execute("truncate foo") + if err != nil { + t.Fatalf("Unexpected error truncating foo: %v", err) + } + + // It rolls back when passed function returns false + committed, err = conn.Transaction(func() bool { + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + return false + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: ", err) + } + if committed { + t.Fatal("Transaction should not have been committed") + } + n, err = conn.SelectValue("select count(*) from foo") + if err != nil { + t.Fatalf("Unexpected error selecting value from foo: %v", err) + } + if n.(int64) != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // it rolls back changes when connection is in error state + committed, err = conn.Transaction(func() bool { + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + if _, err := conn.Execute("invalid"); err == nil { + t.Fatal("Execute was supposed to error but didn't") + } + return true + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: %v", err) + } + if committed { + t.Fatal("Transaction was committed when it shouldn't have been") + } + n, err = conn.SelectValue("select count(*) from foo") + if err != nil { + t.Fatalf("Unexpected error selecting value from foo: %v", err) + } + if n.(int64) != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // when commit fails + committed, err = conn.Transaction(func() bool { + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + return true + }) + if err == nil { + t.Fatal("Transaction should have failed but didn't") + } + if committed { + t.Fatal("Transaction was committed when it should have failed") + } + + n, err = conn.SelectValue("select count(*) from foo") + if err != nil { + t.Fatalf("Unexpected error selecting value from foo: %v", err) + } + if n.(int64) != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // when something in transaction panicks + func() { + defer func() { + recover() + }() + + committed, err = conn.Transaction(func() bool { + if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { + t.Fatalf("Failed to insert into table: %v", err) + } + panic("stop!") + return true + }) + + n, err = conn.SelectValue("select count(*) from foo") + if err != nil { + t.Fatalf("Unexpected error selecting value from foo: %v", err) + } + if n.(int64) != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + }() +}