mirror of https://github.com/jackc/pgx.git
Add transaction context support
parent
2df4b1406b
commit
d1fd222ca5
|
@ -410,7 +410,7 @@ func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExO
|
|||
// Begin acquires a connection and begins a transaction on it. When the
|
||||
// transaction is closed the connection will be automatically released.
|
||||
func (p *ConnPool) Begin() (*Tx, error) {
|
||||
return p.BeginEx(nil)
|
||||
return p.BeginEx(context.Background(), nil)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement on a connection in the pool to test the
|
||||
|
@ -499,14 +499,14 @@ func (p *ConnPool) Deallocate(name string) (err error) {
|
|||
// BeginEx acquires a connection and starts a transaction with txOptions
|
||||
// determining the transaction mode. When the transaction is closed the
|
||||
// connection will be automatically released.
|
||||
func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) {
|
||||
func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||
for {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := c.BeginEx(txOptions)
|
||||
tx, err := c.BeginEx(ctx, txOptions)
|
||||
if err != nil {
|
||||
alive := c.IsAlive()
|
||||
p.Release(c)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -635,7 +636,7 @@ func TestConnPoolTransactionIso(t *testing.T) {
|
|||
pool := createConnPool(t, 2)
|
||||
defer pool.Close()
|
||||
|
||||
tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
if err != nil {
|
||||
t.Fatalf("pool.BeginEx failed: %v", err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package pgmock
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
|
@ -38,6 +39,9 @@ func (s *Server) ServeOne() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
s.Close()
|
||||
|
||||
backend, err := pgproto3.NewBackend(conn, conn)
|
||||
if err != nil {
|
||||
|
@ -167,6 +171,27 @@ func SendMessage(msg pgproto3.BackendMessage) Step {
|
|||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
type waitForCloseMessageStep struct{}
|
||||
|
||||
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
for {
|
||||
msg, err := backend.Receive()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForClose() Step {
|
||||
return &waitForCloseMessageStep{}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
|
|
|
@ -267,7 +267,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
|||
pgxOpts.AccessMode = pgx.ReadOnly
|
||||
}
|
||||
|
||||
return c.conn.BeginEx(&pgxOpts)
|
||||
return c.conn.BeginEx(ctx, &pgxOpts)
|
||||
}
|
||||
|
||||
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
||||
|
|
|
@ -847,6 +847,7 @@ func TestConnPingContextCancel(t *testing.T) {
|
|||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: ";"}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
|
@ -855,7 +856,7 @@ func TestConnPingContextCancel(t *testing.T) {
|
|||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error)
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
@ -864,7 +865,7 @@ func TestConnPingContextCancel(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
// defer closeDB(t, db) // mock DB doesn't close correctly yet
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
|
@ -900,6 +901,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
|
|||
pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
|
||||
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
|
@ -917,7 +919,7 @@ func TestConnPrepareContextCancel(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
// defer closeDB(t, db) // mock DB doesn't close correctly yet
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
|
@ -950,6 +952,7 @@ func TestConnExecContextCancel(t *testing.T) {
|
|||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
|
@ -967,7 +970,7 @@ func TestConnExecContextCancel(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
// defer closeDB(t, db) // mock DB doesn't close correctly yet
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
|
|
28
tx.go
28
tx.go
|
@ -2,8 +2,10 @@ package pgx
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TxIsoLevel string
|
||||
|
@ -56,12 +58,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
|
|||
// Begin starts a transaction with the default transaction mode for the
|
||||
// current connection. To use a specific transaction mode see BeginEx.
|
||||
func (c *Conn) Begin() (*Tx, error) {
|
||||
return c.BeginEx(nil)
|
||||
return c.BeginEx(context.Background(), nil)
|
||||
}
|
||||
|
||||
// BeginEx starts a transaction with txOptions determining the transaction
|
||||
// mode.
|
||||
func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) {
|
||||
// mode. Unlike database/sql, the context only affects the begin command. i.e.
|
||||
// there is no auto-rollback on context cancelation.
|
||||
func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) {
|
||||
var beginSQL string
|
||||
if txOptions == nil {
|
||||
beginSQL = "begin"
|
||||
|
@ -81,8 +84,11 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) {
|
|||
beginSQL = buf.String()
|
||||
}
|
||||
|
||||
_, err := c.Exec(beginSQL)
|
||||
_, err := c.ExecEx(ctx, beginSQL, nil)
|
||||
if err != nil {
|
||||
// begin should never fail unless there is an underlying connection issue or
|
||||
// a context timeout. In either case, the connection is possibly broken.
|
||||
c.die(errors.New("failed to begin transaction"))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -102,11 +108,16 @@ type Tx struct {
|
|||
|
||||
// Commit commits the transaction
|
||||
func (tx *Tx) Commit() error {
|
||||
return tx.CommitEx(context.Background())
|
||||
}
|
||||
|
||||
// CommitEx commits the transaction with a context.
|
||||
func (tx *Tx) CommitEx(ctx context.Context) error {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return ErrTxClosed
|
||||
}
|
||||
|
||||
commandTag, err := tx.conn.Exec("commit")
|
||||
commandTag, err := tx.conn.ExecEx(ctx, "commit", nil)
|
||||
if err == nil && commandTag == "COMMIT" {
|
||||
tx.status = TxStatusCommitSuccess
|
||||
} else if err == nil && commandTag == "ROLLBACK" {
|
||||
|
@ -115,6 +126,8 @@ func (tx *Tx) Commit() error {
|
|||
} else {
|
||||
tx.status = TxStatusCommitFailure
|
||||
tx.err = err
|
||||
// A commit failure leaves the connection in an undefined state
|
||||
tx.conn.die(errors.New("commit failed"))
|
||||
}
|
||||
|
||||
if tx.connPool != nil {
|
||||
|
@ -133,11 +146,14 @@ func (tx *Tx) Rollback() error {
|
|||
return ErrTxClosed
|
||||
}
|
||||
|
||||
_, tx.err = tx.conn.Exec("rollback")
|
||||
ctx, _ := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
_, tx.err = tx.conn.ExecEx(ctx, "rollback", nil)
|
||||
if tx.err == nil {
|
||||
tx.status = TxStatusRollbackSuccess
|
||||
} else {
|
||||
tx.status = TxStatusRollbackFailure
|
||||
// A rollback failure leaves the connection in an undefined state
|
||||
tx.conn.die(errors.New("rollback failed"))
|
||||
}
|
||||
|
||||
if tx.connPool != nil {
|
||||
|
|
112
tx_test.go
112
tx_test.go
|
@ -1,9 +1,14 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgmock"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
func TestTransactionSuccessfulCommit(t *testing.T) {
|
||||
|
@ -107,13 +112,13 @@ func TestTxCommitSerializationFailure(t *testing.T) {
|
|||
}
|
||||
defer pool.Exec(`drop table tx_serializable_sums`)
|
||||
|
||||
tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
if err != nil {
|
||||
t.Fatalf("BeginEx failed: %v", err)
|
||||
}
|
||||
defer tx1.Rollback()
|
||||
|
||||
tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
||||
if err != nil {
|
||||
t.Fatalf("BeginEx failed: %v", err)
|
||||
}
|
||||
|
@ -190,7 +195,7 @@ func TestBeginExIsoLevels(t *testing.T) {
|
|||
|
||||
isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
|
||||
for _, iso := range isoLevels {
|
||||
tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso})
|
||||
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso})
|
||||
if err != nil {
|
||||
t.Fatalf("conn.BeginEx failed: %v", err)
|
||||
}
|
||||
|
@ -214,7 +219,7 @@ func TestBeginExReadOnly(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly})
|
||||
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly})
|
||||
if err != nil {
|
||||
t.Fatalf("conn.BeginEx failed: %v", err)
|
||||
}
|
||||
|
@ -226,6 +231,105 @@ func TestBeginExReadOnly(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnBeginExContextCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn := mustConnect(t, mockConfig)
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
|
||||
_, err = conn.BeginEx(ctx, nil)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("expected conn to be dead after BeginEx failure")
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxCommitExCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
||||
script.Steps = append(script.Steps,
|
||||
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
||||
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
|
||||
pgmock.WaitForClose(),
|
||||
)
|
||||
|
||||
server, err := pgmock.NewServer(script)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- server.ServeOne()
|
||||
}()
|
||||
|
||||
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn := mustConnect(t, mockConfig)
|
||||
defer conn.Close()
|
||||
|
||||
tx, err := conn.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
err = tx.CommitEx(ctx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
if conn.IsAlive() {
|
||||
t.Error("expected conn to be dead after CommitEx failure")
|
||||
}
|
||||
|
||||
if err := <-errChan; err != nil {
|
||||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue