Add ConnPool context methods

context
Jack Christensen 2017-02-04 21:10:13 -06:00
parent 37b86083e4
commit 14eedb4fca
5 changed files with 100 additions and 4 deletions

View File

@ -1051,9 +1051,12 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
} }
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
c.closingLock.Lock()
if !c.alive { if !c.alive {
c.closingLock.Unlock()
return 0, nil, ErrDeadConn return 0, nil, ErrDeadConn
} }
c.closingLock.Unlock()
t, err = c.mr.rxMsg() t, err = c.mr.rxMsg()
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package pgx
import ( import (
"errors" "errors"
"golang.org/x/net/context"
"sync" "sync"
"time" "time"
) )
@ -357,6 +358,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...) return c.Exec(sql, arguments...)
} }
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.ExecContext(ctx, sql, arguments...)
}
// Query acquires a connection and delegates the call to that connection. When // Query acquires a connection and delegates the call to that connection. When
// *Rows are closed, the connection is released automatically. // *Rows are closed, the connection is released automatically.
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
@ -377,6 +388,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
return rows, nil return rows, nil
} }
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
c, err := p.Acquire()
if err != nil {
// Because checking for errors can be deferred to the *Rows, build one with the error
return &Rows{closed: true, err: err}, err
}
rows, err := c.QueryContext(ctx, sql, args...)
if err != nil {
p.Release(c)
return rows, err
}
rows.AfterClose(p.rowsAfterClose)
return rows, nil
}
// QueryRow acquires a connection and delegates the call to that connection. The // QueryRow acquires a connection and delegates the call to that connection. The
// connection is released automatically after Scan is called on the returned // connection is released automatically after Scan is called on the returned
// *Row. // *Row.
@ -385,6 +414,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows) return (*Row)(rows)
} }
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := p.QueryContext(ctx, sql, args...)
return (*Row)(rows)
}
// Begin acquires a connection and begins a transaction on it. When the // Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released. // transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) { func (p *ConnPool) Begin() (*Tx, error) {

12
context-todo.txt Normal file
View File

@ -0,0 +1,12 @@
Add more testing
- stress test style
- pgmock
Add documentation
Add PrepareContext
Add context methods to ConnPool
Add context methods to Tx
Add context support database/sql
Benchmark - possibly cache done channel on Conn

View File

@ -51,8 +51,9 @@ type Rows struct {
unlockConn bool unlockConn bool
closed bool closed bool
ctx context.Context ctx context.Context
doneChan chan struct{} doneChan chan struct{}
closedChan chan bool
} }
func (rows *Rows) FieldDescriptions() []FieldDescription { func (rows *Rows) FieldDescriptions() []FieldDescription {
@ -127,7 +128,7 @@ func (rows *Rows) Close() {
if rows.ctx != nil { if rows.ctx != nil {
select { select {
case <-rows.ctx.Done(): case <-rows.closedChan:
rows.err = rows.ctx.Err() rows.err = rows.ctx.Err()
case rows.doneChan <- struct{}{}: case rows.doneChan <- struct{}{}:
} }
@ -508,12 +509,14 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
doneChan := make(chan struct{}) doneChan := make(chan struct{})
closedChan := make(chan bool)
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.cancelQuery() c.cancelQuery()
c.Close() c.Close()
closedChan <- true
case <-doneChan: case <-doneChan:
} }
}() }()
@ -522,7 +525,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
if err != nil { if err != nil {
select { select {
case <-ctx.Done(): case <-closedChan:
return rows, ctx.Err() return rows, ctx.Err()
case doneChan <- struct{}{}: case doneChan <- struct{}{}:
return rows, err return rows, err
@ -531,6 +534,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
rows.ctx = ctx rows.ctx = ctx
rows.doneChan = doneChan rows.doneChan = doneChan
rows.closedChan = closedChan
return rows, nil return rows, nil
} }

View File

@ -3,6 +3,7 @@ package pgx_test
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/context"
"math/rand" "math/rand"
"testing" "testing"
"time" "time"
@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) {
{"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"listenAndPoolUnlistens", listenAndPoolUnlistens},
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
{"canceledQueryContext", canceledQueryContext},
{"canceledExecContext", canceledExecContext},
} }
var timer *time.Timer var timer *time.Timer
@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
return tx.Commit() return tx.Commit()
} }
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
rows, err := pool.QueryContext(ctx, "select pg_sleep(5)")
if err == context.Canceled {
return nil
} else if err != nil {
return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err)
}
for rows.Next() {
return errors.New("canceledQueryContext: should never receive row")
}
if rows.Err() != context.Canceled {
return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err())
}
return nil
}
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
_, err := pool.ExecContext(ctx, "select pg_sleep(5)")
if err != context.Canceled {
return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err)
}
return nil
}