diff --git a/conn.go b/conn.go index 45bb9441..f7c06014 100644 --- a/conn.go +++ b/conn.go @@ -1051,9 +1051,12 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { + c.closingLock.Lock() if !c.alive { + c.closingLock.Unlock() return 0, nil, ErrDeadConn } + c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { diff --git a/conn_pool.go b/conn_pool.go index 6d04565d..50b9d588 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,6 +2,7 @@ package pgx import ( "errors" + "golang.org/x/net/context" "sync" "time" ) @@ -357,6 +358,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.ExecContext(ctx, sql, arguments...) +} + // Query acquires a connection and delegates the call to that connection. When // *Rows are closed, the connection is released automatically. func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { @@ -377,6 +388,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, nil } +func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.AfterClose(p.rowsAfterClose) + + return rows, nil +} + // QueryRow acquires a connection and delegates the call to that connection. The // connection is released automatically after Scan is called on the returned // *Row. @@ -385,6 +414,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := p.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} + // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { diff --git a/context-todo.txt b/context-todo.txt new file mode 100644 index 00000000..b5a20d0a --- /dev/null +++ b/context-todo.txt @@ -0,0 +1,12 @@ +Add more testing +- stress test style +- pgmock + +Add documentation + +Add PrepareContext +Add context methods to ConnPool +Add context methods to Tx +Add context support database/sql + +Benchmark - possibly cache done channel on Conn diff --git a/query.go b/query.go index fc3f405b..3ded881d 100644 --- a/query.go +++ b/query.go @@ -51,8 +51,9 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} + ctx context.Context + doneChan chan struct{} + closedChan chan bool } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -127,7 +128,7 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.ctx.Done(): + case <-rows.closedChan: rows.err = rows.ctx.Err() case rows.doneChan <- struct{}{}: } @@ -508,12 +509,14 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { doneChan := make(chan struct{}) + closedChan := make(chan bool) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() + closedChan <- true case <-doneChan: } }() @@ -522,7 +525,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { - case <-ctx.Done(): + case <-closedChan: return rows, ctx.Err() case doneChan <- struct{}{}: return rows, err @@ -531,6 +534,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} rows.ctx = ctx rows.doneChan = doneChan + rows.closedChan = closedChan return rows, nil } diff --git a/stress_test.go b/stress_test.go index 150d13c8..d22d9d6b 100644 --- a/stress_test.go +++ b/stress_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "errors" "fmt" + "golang.org/x/net/context" "math/rand" "testing" "time" @@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + {"canceledQueryContext", canceledQueryContext}, + {"canceledExecContext", canceledExecContext}, } var timer *time.Timer @@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } + +func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + if err == context.Canceled { + return nil + } else if err != nil { + return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + } + + for rows.Next() { + return errors.New("canceledQueryContext: should never receive row") + } + + if rows.Err() != context.Canceled { + return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + } + + return nil +} + +func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + if err != context.Canceled { + return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + } + + return nil +}