Add transaction context support

batch-wip
Jack Christensen 2017-05-20 17:58:19 -05:00
parent 2df4b1406b
commit d1fd222ca5
7 changed files with 168 additions and 19 deletions

View File

@ -410,7 +410,7 @@ func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExO
// Begin acquires a connection and begins a transaction on it. When the // Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released. // transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) { func (p *ConnPool) Begin() (*Tx, error) {
return p.BeginEx(nil) return p.BeginEx(context.Background(), nil)
} }
// Prepare creates a prepared statement on a connection in the pool to test the // Prepare creates a prepared statement on a connection in the pool to test the
@ -499,14 +499,14 @@ func (p *ConnPool) Deallocate(name string) (err error) {
// BeginEx acquires a connection and starts a transaction with txOptions // BeginEx acquires a connection and starts a transaction with txOptions
// determining the transaction mode. When the transaction is closed the // determining the transaction mode. When the transaction is closed the
// connection will be automatically released. // connection will be automatically released.
func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
for { for {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return nil, err return nil, err
} }
tx, err := c.BeginEx(txOptions) tx, err := c.BeginEx(ctx, txOptions)
if err != nil { if err != nil {
alive := c.IsAlive() alive := c.IsAlive()
p.Release(c) p.Release(c)

View File

@ -1,6 +1,7 @@
package pgx_test package pgx_test
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -635,7 +636,7 @@ func TestConnPoolTransactionIso(t *testing.T) {
pool := createConnPool(t, 2) pool := createConnPool(t, 2)
defer pool.Close() defer pool.Close()
tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil { if err != nil {
t.Fatalf("pool.BeginEx failed: %v", err) t.Fatalf("pool.BeginEx failed: %v", err)
} }

View File

@ -3,6 +3,7 @@ package pgmock
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"reflect" "reflect"
@ -38,6 +39,9 @@ func (s *Server) ServeOne() error {
if err != nil { if err != nil {
return err return err
} }
defer conn.Close()
s.Close()
backend, err := pgproto3.NewBackend(conn, conn) backend, err := pgproto3.NewBackend(conn, conn)
if err != nil { if err != nil {
@ -167,6 +171,27 @@ func SendMessage(msg pgproto3.BackendMessage) Step {
return &sendMessageStep{msg: msg} return &sendMessageStep{msg: msg}
} }
type waitForCloseMessageStep struct{}
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
for {
msg, err := backend.Receive()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
if _, ok := msg.(*pgproto3.Terminate); ok {
return nil
}
}
}
func WaitForClose() Step {
return &waitForCloseMessageStep{}
}
func AcceptUnauthenticatedConnRequestSteps() []Step { func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{ return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),

View File

@ -267,7 +267,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
pgxOpts.AccessMode = pgx.ReadOnly pgxOpts.AccessMode = pgx.ReadOnly
} }
return c.conn.BeginEx(&pgxOpts) return c.conn.BeginEx(ctx, &pgxOpts)
} }
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {

View File

@ -847,6 +847,7 @@ func TestConnPingContextCancel(t *testing.T) {
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps, script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), pgmock.ExpectMessage(&pgproto3.Query{String: ";"}),
pgmock.WaitForClose(),
) )
server, err := pgmock.NewServer(script) server, err := pgmock.NewServer(script)
@ -855,7 +856,7 @@ func TestConnPingContextCancel(t *testing.T) {
} }
defer server.Close() defer server.Close()
errChan := make(chan error) errChan := make(chan error, 1)
go func() { go func() {
errChan <- server.ServeOne() errChan <- server.ServeOne()
}() }()
@ -864,7 +865,7 @@ func TestConnPingContextCancel(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("sql.Open failed: %v", err) t.Fatalf("sql.Open failed: %v", err)
} }
// defer closeDB(t, db) // mock DB doesn't close correctly yet defer closeDB(t, db)
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
@ -900,6 +901,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}),
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
pgmock.ExpectMessage(&pgproto3.Sync{}), pgmock.ExpectMessage(&pgproto3.Sync{}),
pgmock.WaitForClose(),
) )
server, err := pgmock.NewServer(script) server, err := pgmock.NewServer(script)
@ -917,7 +919,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("sql.Open failed: %v", err) t.Fatalf("sql.Open failed: %v", err)
} }
// defer closeDB(t, db) // mock DB doesn't close correctly yet defer closeDB(t, db)
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
@ -950,6 +952,7 @@ func TestConnExecContextCancel(t *testing.T) {
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps, script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}),
pgmock.WaitForClose(),
) )
server, err := pgmock.NewServer(script) server, err := pgmock.NewServer(script)
@ -967,7 +970,7 @@ func TestConnExecContextCancel(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("sql.Open failed: %v", err) t.Fatalf("sql.Open failed: %v", err)
} }
// defer closeDB(t, db) // mock DB doesn't close correctly yet defer closeDB(t, db)
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)

28
tx.go
View File

@ -2,8 +2,10 @@ package pgx
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"time"
) )
type TxIsoLevel string type TxIsoLevel string
@ -56,12 +58,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
// Begin starts a transaction with the default transaction mode for the // Begin starts a transaction with the default transaction mode for the
// current connection. To use a specific transaction mode see BeginEx. // current connection. To use a specific transaction mode see BeginEx.
func (c *Conn) Begin() (*Tx, error) { func (c *Conn) Begin() (*Tx, error) {
return c.BeginEx(nil) return c.BeginEx(context.Background(), nil)
} }
// BeginEx starts a transaction with txOptions determining the transaction // BeginEx starts a transaction with txOptions determining the transaction
// mode. // mode. Unlike database/sql, the context only affects the begin command. i.e.
func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { // there is no auto-rollback on context cancelation.
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
var beginSQL string var beginSQL string
if txOptions == nil { if txOptions == nil {
beginSQL = "begin" beginSQL = "begin"
@ -81,8 +84,11 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) {
beginSQL = buf.String() beginSQL = buf.String()
} }
_, err := c.Exec(beginSQL) _, err := c.ExecEx(ctx, beginSQL, nil)
if err != nil { if err != nil {
// begin should never fail unless there is an underlying connection issue or
// a context timeout. In either case, the connection is possibly broken.
c.die(errors.New("failed to begin transaction"))
return nil, err return nil, err
} }
@ -102,11 +108,16 @@ type Tx struct {
// Commit commits the transaction // Commit commits the transaction
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
return tx.CommitEx(context.Background())
}
// CommitEx commits the transaction with a context.
func (tx *Tx) CommitEx(ctx context.Context) error {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return ErrTxClosed return ErrTxClosed
} }
commandTag, err := tx.conn.Exec("commit") commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
if err == nil && commandTag == "COMMIT" { if err == nil && commandTag == "COMMIT" {
tx.status = TxStatusCommitSuccess tx.status = TxStatusCommitSuccess
} else if err == nil && commandTag == "ROLLBACK" { } else if err == nil && commandTag == "ROLLBACK" {
@ -115,6 +126,8 @@ func (tx *Tx) Commit() error {
} else { } else {
tx.status = TxStatusCommitFailure tx.status = TxStatusCommitFailure
tx.err = err tx.err = err
// A commit failure leaves the connection in an undefined state
tx.conn.die(errors.New("commit failed"))
} }
if tx.connPool != nil { if tx.connPool != nil {
@ -133,11 +146,14 @@ func (tx *Tx) Rollback() error {
return ErrTxClosed return ErrTxClosed
} }
_, tx.err = tx.conn.Exec("rollback") ctx, _ := context.WithTimeout(context.Background(), 15*time.Second)
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil)
if tx.err == nil { if tx.err == nil {
tx.status = TxStatusRollbackSuccess tx.status = TxStatusRollbackSuccess
} else { } else {
tx.status = TxStatusRollbackFailure tx.status = TxStatusRollbackFailure
// A rollback failure leaves the connection in an undefined state
tx.conn.die(errors.New("rollback failed"))
} }
if tx.connPool != nil { if tx.connPool != nil {

View File

@ -1,9 +1,14 @@
package pgx_test package pgx_test
import ( import (
"context"
"fmt"
"testing" "testing"
"time"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"github.com/jackc/pgx/pgmock"
"github.com/jackc/pgx/pgproto3"
) )
func TestTransactionSuccessfulCommit(t *testing.T) { func TestTransactionSuccessfulCommit(t *testing.T) {
@ -107,13 +112,13 @@ func TestTxCommitSerializationFailure(t *testing.T) {
} }
defer pool.Exec(`drop table tx_serializable_sums`) defer pool.Exec(`drop table tx_serializable_sums`)
tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil { if err != nil {
t.Fatalf("BeginEx failed: %v", err) t.Fatalf("BeginEx failed: %v", err)
} }
defer tx1.Rollback() defer tx1.Rollback()
tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
if err != nil { if err != nil {
t.Fatalf("BeginEx failed: %v", err) t.Fatalf("BeginEx failed: %v", err)
} }
@ -190,7 +195,7 @@ func TestBeginExIsoLevels(t *testing.T) {
isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
for _, iso := range isoLevels { for _, iso := range isoLevels {
tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso}) tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso})
if err != nil { if err != nil {
t.Fatalf("conn.BeginEx failed: %v", err) t.Fatalf("conn.BeginEx failed: %v", err)
} }
@ -214,7 +219,7 @@ func TestBeginExReadOnly(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn) defer closeConn(t, conn)
tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly}) tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly})
if err != nil { if err != nil {
t.Fatalf("conn.BeginEx failed: %v", err) t.Fatalf("conn.BeginEx failed: %v", err)
} }
@ -226,6 +231,105 @@ func TestBeginExReadOnly(t *testing.T) {
} }
} }
func TestConnBeginExContextCancel(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatal(err)
}
conn := mustConnect(t, mockConfig)
ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)
_, err = conn.BeginEx(ctx, nil)
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}
if conn.IsAlive() {
t.Error("expected conn to be dead after BeginEx failure")
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestTxCommitExCancel(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
script.Steps = append(script.Steps,
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
pgmock.WaitForClose(),
)
server, err := pgmock.NewServer(script)
if err != nil {
t.Fatal(err)
}
defer server.Close()
errChan := make(chan error, 1)
go func() {
errChan <- server.ServeOne()
}()
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
if err != nil {
t.Fatal(err)
}
conn := mustConnect(t, mockConfig)
defer conn.Close()
tx, err := conn.Begin()
if err != nil {
t.Fatal(err)
}
ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)
err = tx.CommitEx(ctx)
if err != context.DeadlineExceeded {
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
}
if conn.IsAlive() {
t.Error("expected conn to be dead after CommitEx failure")
}
if err := <-errChan; err != nil {
t.Errorf("mock server err: %v", err)
}
}
func TestTxStatus(t *testing.T) { func TestTxStatus(t *testing.T) {
t.Parallel() t.Parallel()