From ae6a87545bc09f0d8779286bc4c1e54614606817 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jul 2014 07:59:30 -0500 Subject: [PATCH] Use database/sql style transaction interface --- conn.go | 8 --- conn_pool.go | 30 ---------- conn_pool_test.go | 107 +++++++++------------------------- tx.go | 86 ++++++++++----------------- tx_test.go | 145 ---------------------------------------------- 5 files changed, 60 insertions(+), 316 deletions(-) diff --git a/conn.go b/conn.go index bec30527..54a5599a 100644 --- a/conn.go +++ b/conn.go @@ -23,14 +23,6 @@ import ( "time" ) -// Transaction isolation levels -const ( - Serializable = "serializable" - RepeatableRead = "repeatable read" - ReadCommitted = "read committed" - ReadUncommitted = "read uncommitted" -) - // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) diff --git a/conn_pool.go b/conn_pool.go index c13197d0..96119fad 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -232,33 +232,3 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { tx.pool = p return tx, nil } - -// Transaction acquires a connection, delegates the call to that connection, -// and releases the connection. The call signature differs slightly from the -// underlying Transaction in that the callback function accepts a *Conn -func (p *ConnPool) Transaction(f func(conn *Conn) bool) (committed bool, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.Transaction(func() bool { - return f(c) - }) -} - -// TransactionIso acquires a connection, delegates the call to that connection, -// and releases the connection. The call signature differs slightly from the -// underlying TransactionIso in that the callback function accepts a *Conn -func (p *ConnPool) TransactionIso(isoLevel string, f func(conn *Conn) bool) (committed bool, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.TransactionIso(isoLevel, func() bool { - return f(c) - }) -} diff --git a/conn_pool_test.go b/conn_pool_test.go index cfc8b451..615faa03 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -309,113 +309,62 @@ func TestConnPoolTransaction(t *testing.T) { pool := createConnPool(t, 2) defer pool.Close() + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + tx, err := pool.Begin() if err != nil { t.Fatalf("pool.Begin failed: %v", err) } + defer tx.Rollback() var n int32 - err := pool.QueryRow("select 40+$1", 2).Scan(&n) + err = tx.QueryRow("select 40+$1", 2).Scan(&n) if err != nil { - t.Fatalf("pool.QueryRow Scan failed: %v", err) + t.Fatalf("tx.QueryRow Scan failed: %v", err) } - if n != 42 { t.Errorf("Expected 42, got %d", n) } - stats := pool.Stat() + stats = pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + + stats = pool.Stat() if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { t.Fatalf("Unexpected connection pool stats: %v", stats) } } -func TestPoolTransaction(t *testing.T) { +func TestConnPoolTransactionIso(t *testing.T) { t.Parallel() pool := createConnPool(t, 2) defer pool.Close() - committed, err := pool.Transaction(func(conn *pgx.Conn) bool { - mustExec(t, conn, "create temporary table foo(id serial primary key)") - return true - }) + tx, err := pool.BeginIso(pgx.Serializable) if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if !committed { - t.Fatal("Transaction was not committed when it should have been") + t.Fatalf("pool.Begin failed: %v", err) } + defer tx.Rollback() - committed, err = pool.Transaction(func(conn *pgx.Conn) bool { - var n int64 - err := conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow.Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive expected value: %v", n) - } - - mustExec(t, conn, "insert into foo(id) values(default)") - - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow.Scan failed: %v", err) - } - if n != 1 { - t.Fatalf("Did not receive expected value: %v", n) - } - - return false - }) + var level string + err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&level) if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if committed { - t.Fatal("Transaction was committed when it shouldn't have been") + t.Fatalf("tx.QueryRow failed: %v", level) } - committed, err = pool.Transaction(func(conn *pgx.Conn) bool { - var n int64 - err := conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow.Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive expected value: %v", n) - } - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if !committed { - t.Fatal("Transaction was not committed when it should have been") - } - -} - -func TestPoolTransactionIso(t *testing.T) { - t.Parallel() - - pool := createConnPool(t, 2) - defer pool.Close() - - committed, err := pool.TransactionIso("serializable", func(conn *pgx.Conn) bool { - var level string - conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) - - if level != "serializable" { - t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level) - } - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if !committed { - t.Fatal("Transaction was not committed when it should have been") + if level != "serializable" { + t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level) } } diff --git a/tx.go b/tx.go index 144f22a0..c5ddcf2c 100644 --- a/tx.go +++ b/tx.go @@ -5,10 +5,30 @@ import ( "fmt" ) +// Transaction isolation levels +const ( + Serializable = "serializable" + RepeatableRead = "repeatable read" + ReadCommitted = "read committed" + ReadUncommitted = "read uncommitted" +) + +var ErrTxClosed = errors.New("tx is closed") + +// Begin starts a transaction with the default isolation level for the current +// connection. To use a specific isolation level see BeginIso. func (c *Conn) Begin() (*Tx, error) { return c.begin("") } +// BeginIso starts a transaction with isoLevel as the transaction isolation +// level. +// +// Valid isolation levels (and their constants) are: +// serializable (pgx.Serializable) +// repeatable read (pgx.RepeatableRead) +// read committed (pgx.ReadCommitted) +// read uncommitted (pgx.ReadUncommitted) func (c *Conn) BeginIso(isoLevel string) (*Tx, error) { return c.begin(isoLevel) } @@ -29,6 +49,10 @@ func (c *Conn) begin(isoLevel string) (*Tx, error) { return &Tx{conn: c}, nil } +// Tx represents a database transaction. +// +// All Tx methods return ErrTxClosed if Commit or Rollback has already been +// called on the Tx. type Tx struct { pool *ConnPool conn *Conn @@ -38,7 +62,7 @@ type Tx struct { // Commit commits the transaction func (tx *Tx) Commit() error { if tx.closed { - return errors.New("tx is closed") + return ErrTxClosed } _, err := tx.conn.Exec("commit") @@ -46,10 +70,13 @@ func (tx *Tx) Commit() error { return err } -// Rollback rolls back the transaction +// Rollback rolls back the transaction. Rollback will return ErrTxClosed if the +// Tx is already closed, but is otherwise safe to call multiple times. Hence, a +// defer tx.Rollback() is safe even if tx.Commit() will be called first in a +// non-error condition. func (tx *Tx) Rollback() error { if tx.closed { - return errors.New("tx is closed") + return ErrTxClosed } _, err := tx.conn.Exec("rollback") @@ -68,7 +95,7 @@ func (tx *Tx) close() { // Exec delegates to the underlying *Conn func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { if tx.closed { - return CommandTag(""), errors.New("tx is closed") + return CommandTag(""), ErrTxClosed } return tx.conn.Exec(sql, arguments...) @@ -78,7 +105,7 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error - err := errors.New("tx is closed") + err := ErrTxClosed return &Rows{closed: true, err: err}, err } @@ -90,52 +117,3 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { rows, _ := tx.conn.Query(sql, args...) return (*Row)(rows) } - -// Transaction runs f in a transaction. f should return true if the transaction -// should be committed or false if it should be rolled back. Return value committed -// is if the transaction was committed or not. committed should be checked separately -// from err as an explicit rollback is not an error. Transaction will use the default -// isolation level for the current connection. To use a specific isolation level see -// TransactionIso -func (c *Conn) Transaction(f func() bool) (committed bool, err error) { - return c.transaction("", f) -} - -// TransactionIso is the same as Transaction except it takes an isoLevel argument that -// it uses as the transaction isolation level. -// -// Valid isolation levels (and their constants) are: -// serializable (pgx.Serializable) -// repeatable read (pgx.RepeatableRead) -// read committed (pgx.ReadCommitted) -// read uncommitted (pgx.ReadUncommitted) -func (c *Conn) TransactionIso(isoLevel string, f func() bool) (committed bool, err error) { - return c.transaction(isoLevel, f) -} - -func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err error) { - var beginSql string - if isoLevel == "" { - beginSql = "begin" - } else { - beginSql = fmt.Sprintf("begin isolation level %s", isoLevel) - } - - if _, err = c.Exec(beginSql); err != nil { - return - } - defer func() { - if committed && c.TxStatus == 'T' { - _, err = c.Exec("commit") - if err != nil { - committed = false - } - } else { - _, err = c.Exec("rollback") - committed = false - } - }() - - committed = f() - return -} diff --git a/tx_test.go b/tx_test.go index a0e0a805..6fb70719 100644 --- a/tx_test.go +++ b/tx_test.go @@ -114,148 +114,3 @@ func TestBeginIso(t *testing.T) { } } } - -func TestTransaction(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - createSql := ` - create temporary table foo( - id integer, - unique (id) initially deferred - ); - ` - - if _, err := conn.Exec(createSql); err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // Transaction happy path -- it executes function and commits - committed, err := conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if !committed { - t.Fatal("Transaction was not committed") - } - - var n int64 - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 1 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - mustExec(t, conn, "truncate foo") - - // It rolls back when passed function returns false - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - return false - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if committed { - t.Fatal("Transaction should not have been committed") - } - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // it rolls back changes when connection is in error state - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - if _, err := conn.Exec("invalid"); err == nil { - t.Fatal("Exec was supposed to error but didn't") - } - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if committed { - t.Fatal("Transaction was committed when it shouldn't have been") - } - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // when commit fails - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - mustExec(t, conn, "insert into foo(id) values (1)") - return true - }) - if err == nil { - t.Fatal("Transaction should have failed but didn't") - } - if committed { - t.Fatal("Transaction was committed when it should have failed") - } - - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // when something in transaction panics - func() { - defer func() { - recover() - }() - - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - panic("stop!") - }) - - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - }() -} - -func TestTransactionIso(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} - for _, iso := range isoLevels { - _, err := conn.TransactionIso(iso, func() bool { - var level string - conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) - if level != iso { - t.Errorf("Expected to be in isolation level %v but was %v", iso, level) - } - return true - }) - if err != nil { - t.Fatalf("Unexpected transaction failure: %v", err) - } - } -}