diff --git a/conn.go b/conn.go index 7b70e045..bec30527 100644 --- a/conn.go +++ b/conn.go @@ -879,55 +879,6 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag } } -// Transaction runs f in a transaction. f should return true if the transaction -// should be committed or false if it should be rolled back. Return value committed -// is if the transaction was committed or not. committed should be checked separately -// from err as an explicit rollback is not an error. Transaction will use the default -// isolation level for the current connection. To use a specific isolation level see -// TransactionIso -func (c *Conn) Transaction(f func() bool) (committed bool, err error) { - return c.transaction("", f) -} - -// TransactionIso is the same as Transaction except it takes an isoLevel argument that -// it uses as the transaction isolation level. -// -// Valid isolation levels (and their constants) are: -// serializable (pgx.Serializable) -// repeatable read (pgx.RepeatableRead) -// read committed (pgx.ReadCommitted) -// read uncommitted (pgx.ReadUncommitted) -func (c *Conn) TransactionIso(isoLevel string, f func() bool) (committed bool, err error) { - return c.transaction(isoLevel, f) -} - -func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err error) { - var beginSql string - if isoLevel == "" { - beginSql = "begin" - } else { - beginSql = fmt.Sprintf("begin isolation level %s", isoLevel) - } - - if _, err = c.Exec(beginSql); err != nil { - return - } - defer func() { - if committed && c.TxStatus == 'T' { - _, err = c.Exec("commit") - if err != nil { - committed = false - } - } else { - _, err = c.Exec("rollback") - committed = false - } - }() - - committed = f() - return -} - // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. diff --git a/conn_pool.go b/conn_pool.go index b24a1b70..c13197d0 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -198,6 +198,41 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +// Begin acquires a connection and begins a transaction on it. When the +// transaction is closed the connection will be automatically released. +func (p *ConnPool) Begin() (*Tx, error) { + c, err := p.Acquire() + if err != nil { + return nil, err + } + + tx, err := c.Begin() + if err != nil { + return nil, err + } + + tx.pool = p + return tx, nil +} + +// BeginIso acquires a connection and begins a transaction in isolation mode iso +// on it. When the transaction is closed the connection will be automatically +// released. +func (p *ConnPool) BeginIso(iso string) (*Tx, error) { + c, err := p.Acquire() + if err != nil { + return nil, err + } + + tx, err := c.BeginIso(iso) + if err != nil { + return nil, err + } + + tx.pool = p + return tx, nil +} + // Transaction acquires a connection, delegates the call to that connection, // and releases the connection. The call signature differs slightly from the // underlying Transaction in that the callback function accepts a *Conn diff --git a/conn_pool_test.go b/conn_pool_test.go index 0954798e..cfc8b451 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -303,6 +303,33 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } } +func TestConnPoolTransaction(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + tx, err := pool.Begin() + if err != nil { + t.Fatalf("pool.Begin failed: %v", err) + } + + var n int32 + err := pool.QueryRow("select 40+$1", 2).Scan(&n) + if err != nil { + t.Fatalf("pool.QueryRow Scan failed: %v", err) + } + + if n != 42 { + t.Errorf("Expected 42, got %d", n) + } + + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} + func TestPoolTransaction(t *testing.T) { t.Parallel() diff --git a/conn_test.go b/conn_test.go index e68f91c9..04e418e7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -712,151 +712,6 @@ func TestPrepareFailure(t *testing.T) { ensureConnValid(t, conn) } -func TestTransaction(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - createSql := ` - create temporary table foo( - id integer, - unique (id) initially deferred - ); - ` - - if _, err := conn.Exec(createSql); err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // Transaction happy path -- it executes function and commits - committed, err := conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if !committed { - t.Fatal("Transaction was not committed") - } - - var n int64 - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 1 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - mustExec(t, conn, "truncate foo") - - // It rolls back when passed function returns false - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - return false - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if committed { - t.Fatal("Transaction should not have been committed") - } - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // it rolls back changes when connection is in error state - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - if _, err := conn.Exec("invalid"); err == nil { - t.Fatal("Exec was supposed to error but didn't") - } - return true - }) - if err != nil { - t.Fatalf("Transaction unexpectedly failed: %v", err) - } - if committed { - t.Fatal("Transaction was committed when it shouldn't have been") - } - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // when commit fails - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - mustExec(t, conn, "insert into foo(id) values (1)") - return true - }) - if err == nil { - t.Fatal("Transaction should have failed but didn't") - } - if committed { - t.Fatal("Transaction was committed when it should have failed") - } - - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - - // when something in transaction panics - func() { - defer func() { - recover() - }() - - committed, err = conn.Transaction(func() bool { - mustExec(t, conn, "insert into foo(id) values (1)") - panic("stop!") - }) - - err = conn.QueryRow("select count(*) from foo").Scan(&n) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if n != 0 { - t.Fatalf("Did not receive correct number of rows: %v", n) - } - }() -} - -func TestTransactionIso(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} - for _, iso := range isoLevels { - _, err := conn.TransactionIso(iso, func() bool { - var level string - conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) - if level != iso { - t.Errorf("Expected to be in isolation level %v but was %v", iso, level) - } - return true - }) - if err != nil { - t.Fatalf("Unexpected transaction failure: %v", err) - } - } -} - func TestListenNotify(t *testing.T) { t.Parallel() diff --git a/tx.go b/tx.go new file mode 100644 index 00000000..144f22a0 --- /dev/null +++ b/tx.go @@ -0,0 +1,141 @@ +package pgx + +import ( + "errors" + "fmt" +) + +func (c *Conn) Begin() (*Tx, error) { + return c.begin("") +} + +func (c *Conn) BeginIso(isoLevel string) (*Tx, error) { + return c.begin(isoLevel) +} + +func (c *Conn) begin(isoLevel string) (*Tx, error) { + var beginSql string + if isoLevel == "" { + beginSql = "begin" + } else { + beginSql = fmt.Sprintf("begin isolation level %s", isoLevel) + } + + _, err := c.Exec(beginSql) + if err != nil { + return nil, err + } + + return &Tx{conn: c}, nil +} + +type Tx struct { + pool *ConnPool + conn *Conn + closed bool +} + +// Commit commits the transaction +func (tx *Tx) Commit() error { + if tx.closed { + return errors.New("tx is closed") + } + + _, err := tx.conn.Exec("commit") + tx.close() + return err +} + +// Rollback rolls back the transaction +func (tx *Tx) Rollback() error { + if tx.closed { + return errors.New("tx is closed") + } + + _, err := tx.conn.Exec("rollback") + tx.close() + return err +} + +func (tx *Tx) close() { + if tx.pool != nil { + tx.pool.Release(tx.conn) + tx.pool = nil + } + tx.closed = true +} + +// Exec delegates to the underlying *Conn +func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + if tx.closed { + return CommandTag(""), errors.New("tx is closed") + } + + return tx.conn.Exec(sql, arguments...) +} + +// Query delegates to the underlying *Conn +func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { + if tx.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := errors.New("tx is closed") + return &Rows{closed: true, err: err}, err + } + + return tx.conn.Query(sql, args...) +} + +// QueryRow delegates to the underlying *Conn +func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { + rows, _ := tx.conn.Query(sql, args...) + return (*Row)(rows) +} + +// Transaction runs f in a transaction. f should return true if the transaction +// should be committed or false if it should be rolled back. Return value committed +// is if the transaction was committed or not. committed should be checked separately +// from err as an explicit rollback is not an error. Transaction will use the default +// isolation level for the current connection. To use a specific isolation level see +// TransactionIso +func (c *Conn) Transaction(f func() bool) (committed bool, err error) { + return c.transaction("", f) +} + +// TransactionIso is the same as Transaction except it takes an isoLevel argument that +// it uses as the transaction isolation level. +// +// Valid isolation levels (and their constants) are: +// serializable (pgx.Serializable) +// repeatable read (pgx.RepeatableRead) +// read committed (pgx.ReadCommitted) +// read uncommitted (pgx.ReadUncommitted) +func (c *Conn) TransactionIso(isoLevel string, f func() bool) (committed bool, err error) { + return c.transaction(isoLevel, f) +} + +func (c *Conn) transaction(isoLevel string, f func() bool) (committed bool, err error) { + var beginSql string + if isoLevel == "" { + beginSql = "begin" + } else { + beginSql = fmt.Sprintf("begin isolation level %s", isoLevel) + } + + if _, err = c.Exec(beginSql); err != nil { + return + } + defer func() { + if committed && c.TxStatus == 'T' { + _, err = c.Exec("commit") + if err != nil { + committed = false + } + } else { + _, err = c.Exec("rollback") + committed = false + } + }() + + committed = f() + return +} diff --git a/tx_test.go b/tx_test.go new file mode 100644 index 00000000..a0e0a805 --- /dev/null +++ b/tx_test.go @@ -0,0 +1,261 @@ +package pgx_test + +import ( + "github.com/jackc/pgx" + "testing" +) + +func TestTransactionSuccessfulCommit(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 1 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestTransactionSuccessfulRollback(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + tx, err := conn.Begin() + if err != nil { + t.Fatalf("conn.Begin failed: %v", err) + } + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } +} + +func TestBeginIso(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} + for _, iso := range isoLevels { + tx, err := conn.BeginIso(iso) + if err != nil { + t.Fatalf("conn.BeginIso failed: %v", err) + } + + var level string + conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + if level != iso { + t.Errorf("Expected to be in isolation level %v but was %v", iso, level) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + } +} + +func TestTransaction(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + createSql := ` + create temporary table foo( + id integer, + unique (id) initially deferred + ); + ` + + if _, err := conn.Exec(createSql); err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Transaction happy path -- it executes function and commits + committed, err := conn.Transaction(func() bool { + mustExec(t, conn, "insert into foo(id) values (1)") + return true + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: %v", err) + } + if !committed { + t.Fatal("Transaction was not committed") + } + + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 1 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + mustExec(t, conn, "truncate foo") + + // It rolls back when passed function returns false + committed, err = conn.Transaction(func() bool { + mustExec(t, conn, "insert into foo(id) values (1)") + return false + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: %v", err) + } + if committed { + t.Fatal("Transaction should not have been committed") + } + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // it rolls back changes when connection is in error state + committed, err = conn.Transaction(func() bool { + mustExec(t, conn, "insert into foo(id) values (1)") + if _, err := conn.Exec("invalid"); err == nil { + t.Fatal("Exec was supposed to error but didn't") + } + return true + }) + if err != nil { + t.Fatalf("Transaction unexpectedly failed: %v", err) + } + if committed { + t.Fatal("Transaction was committed when it shouldn't have been") + } + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // when commit fails + committed, err = conn.Transaction(func() bool { + mustExec(t, conn, "insert into foo(id) values (1)") + mustExec(t, conn, "insert into foo(id) values (1)") + return true + }) + if err == nil { + t.Fatal("Transaction should have failed but didn't") + } + if committed { + t.Fatal("Transaction was committed when it should have failed") + } + + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + + // when something in transaction panics + func() { + defer func() { + recover() + }() + + committed, err = conn.Transaction(func() bool { + mustExec(t, conn, "insert into foo(id) values (1)") + panic("stop!") + }) + + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { + t.Fatalf("Did not receive correct number of rows: %v", n) + } + }() +} + +func TestTransactionIso(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} + for _, iso := range isoLevels { + _, err := conn.TransactionIso(iso, func() bool { + var level string + conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + if level != iso { + t.Errorf("Expected to be in isolation level %v but was %v", iso, level) + } + return true + }) + if err != nil { + t.Fatalf("Unexpected transaction failure: %v", err) + } + } +}