From d1fd222ca574df832934c4b4ada8fd9efd47d25d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 17:58:19 -0500 Subject: [PATCH] Add transaction context support --- conn_pool.go | 6 +-- conn_pool_test.go | 3 +- pgmock/pgmock.go | 25 ++++++++++ stdlib/sql.go | 2 +- stdlib/sql_test.go | 11 +++-- tx.go | 28 +++++++++--- tx_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 168 insertions(+), 19 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 49de6658..632692de 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -410,7 +410,7 @@ func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExO // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { - return p.BeginEx(nil) + return p.BeginEx(context.Background(), nil) } // Prepare creates a prepared statement on a connection in the pool to test the @@ -499,14 +499,14 @@ func (p *ConnPool) Deallocate(name string) (err error) { // BeginEx acquires a connection and starts a transaction with txOptions // determining the transaction mode. When the transaction is closed the // connection will be automatically released. -func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { +func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { for { c, err := p.Acquire() if err != nil { return nil, err } - tx, err := c.BeginEx(txOptions) + tx, err := c.BeginEx(ctx, txOptions) if err != nil { alive := c.IsAlive() p.Release(c) diff --git a/conn_pool_test.go b/conn_pool_test.go index 42f37eb1..560ab3ae 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "errors" "fmt" "net" @@ -635,7 +636,7 @@ func TestConnPoolTransactionIso(t *testing.T) { pool := createConnPool(t, 2) defer pool.Close() - tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("pool.BeginEx failed: %v", err) } diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 8dccf811..3f1e54f4 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -3,6 +3,7 @@ package pgmock import ( "errors" "fmt" + "io" "net" "reflect" @@ -38,6 +39,9 @@ func (s *Server) ServeOne() error { if err != nil { return err } + defer conn.Close() + + s.Close() backend, err := pgproto3.NewBackend(conn, conn) if err != nil { @@ -167,6 +171,27 @@ func SendMessage(msg pgproto3.BackendMessage) Step { return &sendMessageStep{msg: msg} } +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + func AcceptUnauthenticatedConnRequestSteps() []Step { return []Step{ ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), diff --git a/stdlib/sql.go b/stdlib/sql.go index 088095ab..a0aa6975 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -267,7 +267,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.AccessMode = pgx.ReadOnly } - return c.conn.BeginEx(&pgxOpts) + return c.conn.BeginEx(ctx, &pgxOpts) } func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index aa3ae3ee..415864cd 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -847,6 +847,7 @@ func TestConnPingContextCancel(t *testing.T) { script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -855,7 +856,7 @@ func TestConnPingContextCancel(t *testing.T) { } defer server.Close() - errChan := make(chan error) + errChan := make(chan error, 1) go func() { errChan <- server.ServeOne() }() @@ -864,7 +865,7 @@ func TestConnPingContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -900,6 +901,7 @@ func TestConnPrepareContextCancel(t *testing.T) { pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), pgmock.ExpectMessage(&pgproto3.Sync{}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -917,7 +919,7 @@ func TestConnPrepareContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -950,6 +952,7 @@ func TestConnExecContextCancel(t *testing.T) { script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -967,7 +970,7 @@ func TestConnExecContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) diff --git a/tx.go b/tx.go index ea804449..f5309468 100644 --- a/tx.go +++ b/tx.go @@ -2,8 +2,10 @@ package pgx import ( "bytes" + "context" "errors" "fmt" + "time" ) type TxIsoLevel string @@ -56,12 +58,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") // Begin starts a transaction with the default transaction mode for the // current connection. To use a specific transaction mode see BeginEx. func (c *Conn) Begin() (*Tx, error) { - return c.BeginEx(nil) + return c.BeginEx(context.Background(), nil) } // BeginEx starts a transaction with txOptions determining the transaction -// mode. -func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { +// mode. Unlike database/sql, the context only affects the begin command. i.e. +// there is no auto-rollback on context cancelation. +func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { var beginSQL string if txOptions == nil { beginSQL = "begin" @@ -81,8 +84,11 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { beginSQL = buf.String() } - _, err := c.Exec(beginSQL) + _, err := c.ExecEx(ctx, beginSQL, nil) if err != nil { + // begin should never fail unless there is an underlying connection issue or + // a context timeout. In either case, the connection is possibly broken. + c.die(errors.New("failed to begin transaction")) return nil, err } @@ -102,11 +108,16 @@ type Tx struct { // Commit commits the transaction func (tx *Tx) Commit() error { + return tx.CommitEx(context.Background()) +} + +// CommitEx commits the transaction with a context. +func (tx *Tx) CommitEx(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } - commandTag, err := tx.conn.Exec("commit") + commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) if err == nil && commandTag == "COMMIT" { tx.status = TxStatusCommitSuccess } else if err == nil && commandTag == "ROLLBACK" { @@ -115,6 +126,8 @@ func (tx *Tx) Commit() error { } else { tx.status = TxStatusCommitFailure tx.err = err + // A commit failure leaves the connection in an undefined state + tx.conn.die(errors.New("commit failed")) } if tx.connPool != nil { @@ -133,11 +146,14 @@ func (tx *Tx) Rollback() error { return ErrTxClosed } - _, tx.err = tx.conn.Exec("rollback") + ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) + _, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) if tx.err == nil { tx.status = TxStatusRollbackSuccess } else { tx.status = TxStatusRollbackFailure + // A rollback failure leaves the connection in an undefined state + tx.conn.die(errors.New("rollback failed")) } if tx.connPool != nil { diff --git a/tx_test.go b/tx_test.go index 35abd4eb..b25e1c9f 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,9 +1,14 @@ package pgx_test import ( + "context" + "fmt" "testing" + "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" ) func TestTransactionSuccessfulCommit(t *testing.T) { @@ -107,13 +112,13 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer pool.Exec(`drop table tx_serializable_sums`) - tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("BeginEx failed: %v", err) } defer tx1.Rollback() - tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("BeginEx failed: %v", err) } @@ -190,7 +195,7 @@ func TestBeginExIsoLevels(t *testing.T) { isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso}) + tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso}) if err != nil { t.Fatalf("conn.BeginEx failed: %v", err) } @@ -214,7 +219,7 @@ func TestBeginExReadOnly(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly}) + tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { t.Fatalf("conn.BeginEx failed: %v", err) } @@ -226,6 +231,105 @@ func TestBeginExReadOnly(t *testing.T) { } } +func TestConnBeginExContextCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + + _, err = conn.BeginEx(ctx, nil) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if conn.IsAlive() { + t.Error("expected conn to be dead after BeginEx failure") + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestTxCommitExCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + defer conn.Close() + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + err = tx.CommitEx(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if conn.IsAlive() { + t.Error("expected conn to be dead after CommitEx failure") + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + func TestTxStatus(t *testing.T) { t.Parallel()