Merge branch 'context' into v3-experimental

v3-experimental
Jack Christensen 2017-02-12 15:13:47 -06:00
commit 7fbff4a495
16 changed files with 1035 additions and 205 deletions

View File

@ -51,6 +51,8 @@ install:
- go get -u github.com/shopspring/decimal
- go get -u gopkg.in/inconshreveable/log15.v2
- go get -u github.com/jackc/fake
- go get -u golang.org/x/net/context
- go get -u github.com/jackc/pgmock/pgmsg
script:
- go test -v -race -short ./...

11
conn-lock-todo.txt Normal file
View File

@ -0,0 +1,11 @@
Extract all locking state into a separate struct that will encapsulate locking and state change behavior.
This struct should add or subsume at least the following:
* alive
* closingLock
* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one)
* busy
* lock/unlock
* Tx in-progress
* Rows in-progress
* ConnPool checked-out or checked-in - maybe include reference to conn pool

489
conn.go
View File

@ -8,6 +8,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"golang.org/x/net/context"
"io"
"net"
"net/url"
@ -17,9 +18,17 @@ import (
"regexp"
"strconv"
"strings"
"sync/atomic"
"time"
)
const (
connStatusUninitialized = iota
connStatusClosed
connStatusIdle
connStatusBusy
)
// DialFunc is a function that can be used to connect to a PostgreSQL server
type DialFunc func(network, addr string) (net.Conn, error)
@ -39,13 +48,28 @@ type ConnConfig struct {
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
}
func (cc *ConnConfig) networkAddress() (network, address string) {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
}
return network, address
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
lastActivityTime time.Time // the last time the connection was used
reader *bufio.Reader // buffered reader to improve read performance
conn net.Conn // the underlying TCP or unix domain socket connection
lastActivityTime time.Time // the last time the connection was used
wbuf [1024]byte
writeBuf WriteBuf
pid int32 // backend pid
@ -57,17 +81,26 @@ type Conn struct {
preparedStatements map[string]*PreparedStatement
channels map[string]struct{}
notifications []*Notification
alive bool
causeOfDeath error
logger Logger
logLevel int
mr msgReader
fp *fastpath
pgsqlAfInet *byte
pgsqlAfInet6 *byte
busy bool
poolResetCount int
preallocatedRows []Rows
status int32 // One of connStatus* constants
causeOfDeath error
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
cancelQueryInProgress int32
cancelQueryCompleted chan struct{}
// context support
ctxInProgress bool
doneChan chan struct{}
closedChan chan error
}
// PreparedStatement is a description of a prepared statement
@ -194,17 +227,7 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql
}
}
network := "tcp"
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(c.config.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = c.config.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
}
network, address := c.config.networkAddress()
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
@ -238,15 +261,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
defer func() {
if c != nil && err != nil {
c.conn.Close()
c.alive = false
atomic.StoreInt32(&c.status, connStatusClosed)
}
}()
c.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement)
c.channels = make(map[string]struct{})
c.alive = true
atomic.StoreInt32(&c.status, connStatusIdle)
c.lastActivityTime = time.Now()
c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) {
@ -257,8 +283,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
}
c.reader = bufio.NewReader(c.conn)
c.mr.reader = c.reader
c.mr.reader = bufio.NewReader(c.conn)
msg := newStartupMessage()
@ -389,14 +414,17 @@ func (c *Conn) PID() int32 {
// Close closes a connection. It is safe to call Close on a already closed
// connection.
func (c *Conn) Close() (err error) {
if !c.IsAlive() {
return nil
for {
status := atomic.LoadInt32(&c.status)
if status < connStatusIdle {
return nil
}
if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) {
break
}
}
wbuf := newWriteBuf(c, 'X')
wbuf.closeMsg()
_, err = c.conn.Write(wbuf.buf)
_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
c.die(errors.New("Closed"))
if c.shouldLog(LogLevelInfo) {
@ -614,12 +642,36 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
return c.PrepareExContext(context.Background(), name, sql, opts)
}
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
err = c.initContext(ctx)
if err != nil {
return nil, err
}
ps, err = c.prepareEx(name, sql, opts)
err = c.termContext(err)
return ps, err
}
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
if c.shouldLog(LogLevelError) {
defer func() {
if err != nil {
@ -659,6 +711,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
c.die(err)
return nil, err
}
c.readyForQuery = false
ps = &PreparedStatement{Name: name, SQL: sql}
@ -673,7 +726,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
}
switch t {
case parseComplete:
case parameterDescription:
ps.ParameterOIDs = c.rxParameterDescription(r)
@ -687,7 +739,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
ps.FieldDescriptions[i].DataTypeName = t.Name
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
}
case noData:
case readyForQuery:
c.rxReadyForQuery(r)
@ -705,7 +756,29 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
}
// Deallocate released a prepared statement
func (c *Conn) Deallocate(name string) (err error) {
func (c *Conn) Deallocate(name string) error {
return c.deallocateContext(context.Background(), name)
}
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return err
}
err = c.initContext(ctx)
if err != nil {
return err
}
defer func() {
err = c.termContext(err)
}()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
delete(c.preparedStatements, name)
// close
@ -776,6 +849,17 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil
}
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
cancelFn()
return nil, err
}
cancelFn()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
stopTime := time.Now().Add(timeout)
for {
@ -835,7 +919,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
}
// Wait until there is a byte available before continuing onto the normal msg reading path
_, err = c.reader.Peek(1)
_, err = c.mr.reader.Peek(1)
if err != nil {
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
@ -868,7 +952,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
}
func (c *Conn) IsAlive() bool {
return c.alive
return atomic.LoadInt32(&c.status) >= connStatusIdle
}
func (c *Conn) CauseOfDeath() error {
@ -883,6 +967,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
}
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
if len(args) == 0 {
wbuf := newWriteBuf(c, 'Q')
@ -894,6 +981,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
c.die(err)
return err
}
c.readyForQuery = false
return nil
}
@ -911,6 +999,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments))
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
// bind
wbuf := newWriteBuf(c, 'B')
wbuf.WriteByte(0)
@ -958,6 +1050,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
if err != nil {
c.die(err)
}
c.readyForQuery = false
return err
}
@ -965,91 +1058,52 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
if err = c.lock(); err != nil {
return commandTag, err
}
startTime := time.Now()
c.lastActivityTime = startTime
defer func() {
if err == nil {
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
}
} else {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
}
}
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
var softErr error
for {
var t byte
var r *msgReader
t, r, err = c.rxMsg()
if err != nil {
return commandTag, err
}
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
return commandTag, softErr
case rowDescription:
case dataRow:
case bindComplete:
case commandComplete:
commandTag = CommandTag(r.readCString())
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e
}
}
}
return c.ExecContext(context.Background(), sql, arguments...)
}
// Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages
// is the same regardless of when they occur.
// authentication or query response. The response to these messages is the same
// regardless of when they occur. It also ignores messages that are only
// meaningful in a given context. These messages can occur due to a context
// deadline interrupting message processing. For example, an interrupted query
// may have left DataRow messages on the wire.
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
switch t {
case 'S':
c.rxParameterStatus(r)
return nil
case bindComplete:
case commandComplete:
case dataRow:
case emptyQueryResponse:
case errorResponse:
return c.rxErrorResponse(r)
case noData:
case noticeResponse:
return nil
case emptyQueryResponse:
return nil
case notificationResponse:
c.rxNotificationResponse(r)
return nil
case parameterDescription:
case parseComplete:
case readyForQuery:
c.rxReadyForQuery(r)
case rowDescription:
case 'S':
c.rxParameterStatus(r)
default:
return fmt.Errorf("Received unknown message type: %c", t)
}
return nil
}
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
if !c.alive {
if atomic.LoadInt32(&c.status) < connStatusIdle {
return 0, nil, ErrDeadConn
}
t, err = c.mr.rxMsg()
if err != nil {
c.die(err)
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
c.die(err)
}
}
c.lastActivityTime = time.Now()
@ -1150,6 +1204,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) {
}
func (c *Conn) rxReadyForQuery(r *msgReader) {
c.readyForQuery = true
c.txStatus = r.readByte()
}
@ -1230,25 +1285,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
}
func (c *Conn) die(err error) {
c.alive = false
atomic.StoreInt32(&c.status, connStatusClosed)
c.causeOfDeath = err
c.conn.Close()
}
func (c *Conn) lock() error {
if c.busy {
return ErrConnBusy
if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) {
return nil
}
c.busy = true
return nil
return ErrConnBusy
}
func (c *Conn) unlock() error {
if !c.busy {
return errors.New("unlock conn that is not busy")
if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) {
return nil
}
c.busy = false
return nil
return errors.New("unlock conn that is not busy")
}
func (c *Conn) shouldLog(lvl int) bool {
@ -1286,3 +1339,229 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) {
func quoteIdentifier(s string) string {
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
}
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
// error if unable to deliver the cancel request, but lack of an error does not
// ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
func (c *Conn) cancelQuery() {
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
panic("cancelQuery when cancelQueryInProgress")
}
if err := c.conn.SetDeadline(time.Now()); err != nil {
c.Close() // Close connection if unable to set deadline
return
}
doCancel := func() error {
network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
if err != nil {
return err
}
defer cancelConn.Close()
// If server doesn't process cancellation request in bounded time then abort.
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
return err
}
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
}
return nil
}
go func() {
err := doCancel()
if err != nil {
c.Close() // Something is very wrong. Terminate the connection.
}
c.cancelQueryCompleted <- struct{}{}
}()
}
func (c *Conn) Ping() error {
return c.PingContext(context.Background())
}
func (c *Conn) PingContext(ctx context.Context) error {
_, err := c.ExecContext(ctx, ";")
return err
}
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return "", err
}
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
if err = c.lock(); err != nil {
return commandTag, err
}
startTime := time.Now()
c.lastActivityTime = startTime
defer func() {
if err == nil {
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
}
} else {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
}
}
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
var softErr error
for {
var t byte
var r *msgReader
t, r, err = c.rxMsg()
if err != nil {
return commandTag, err
}
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
return commandTag, softErr
case commandComplete:
commandTag = CommandTag(r.readCString())
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e
}
}
}
return commandTag, err
}
func (c *Conn) initContext(ctx context.Context) error {
if c.ctxInProgress {
return errors.New("ctx already in progress")
}
if ctx.Done() == nil {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.ctxInProgress = true
go c.contextHandler(ctx)
return nil
}
func (c *Conn) termContext(opErr error) error {
if !c.ctxInProgress {
return opErr
}
var err error
select {
case err = <-c.closedChan:
if opErr == nil {
err = nil
}
case c.doneChan <- struct{}{}:
err = opErr
}
c.ctxInProgress = false
return err
}
func (c *Conn) contextHandler(ctx context.Context) {
select {
case <-ctx.Done():
c.cancelQuery()
c.closedChan <- ctx.Err()
case <-c.doneChan:
}
}
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
return nil
}
select {
case <-c.cancelQueryCompleted:
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
if err := c.conn.SetDeadline(time.Time{}); err != nil {
c.Close() // Close connection if unable to disable deadline
return err
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery {
t, r, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case errorResponse:
pgErr := c.rxErrorResponse(r)
if pgErr.Severity == "FATAL" {
return pgErr
}
default:
err = c.processContextFreeMsg(t, r)
if err != nil {
return err
}
}
}
return nil
}

View File

@ -2,6 +2,7 @@ package pgx
import (
"errors"
"golang.org/x/net/context"
"sync"
"time"
)
@ -181,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
// Release gives up use of a connection.
func (p *ConnPool) Release(conn *Conn) {
if conn.ctxInProgress {
panic("should never release when context is in progress")
}
if conn.txStatus != 'I' {
conn.Exec("rollback")
}
@ -357,6 +362,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...)
}
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.ExecContext(ctx, sql, arguments...)
}
// Query acquires a connection and delegates the call to that connection. When
// *Rows are closed, the connection is released automatically.
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
@ -377,6 +392,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
return rows, nil
}
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
c, err := p.Acquire()
if err != nil {
// Because checking for errors can be deferred to the *Rows, build one with the error
return &Rows{closed: true, err: err}, err
}
rows, err := c.QueryContext(ctx, sql, args...)
if err != nil {
p.Release(c)
return rows, err
}
rows.AfterClose(p.rowsAfterClose)
return rows, nil
}
// QueryRow acquires a connection and delegates the call to that connection. The
// connection is released automatically after Scan is called on the returned
// *Row.
@ -385,6 +418,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := p.QueryContext(ctx, sql, args...)
return (*Row)(rows)
}
// Begin acquires a connection and begins a transaction on it. When the
// transaction is closed the connection will be automatically released.
func (p *ConnPool) Begin() (*Tx, error) {

View File

@ -3,6 +3,7 @@ package pgx_test
import (
"crypto/tls"
"fmt"
"golang.org/x/net/context"
"net"
"os"
"reflect"
@ -816,6 +817,64 @@ func TestExecFailure(t *testing.T) {
}
}
func TestExecContextWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
if err != nil {
t.Fatal(err)
}
if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
}
}
func TestExecContextFailureWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
t.Fatal("Expected SQL syntax error")
}
rows, _ := conn.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
}
}
func TestExecContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
if err != context.Canceled {
t.Fatal("Expected context.Canceled err, got %v", err)
}
ensureConnValid(t, conn)
}
func TestPrepare(t *testing.T) {
t.Parallel()

12
context-todo.txt Normal file
View File

@ -0,0 +1,12 @@
Add more testing
- stress test style
- pgmock
Add documentation
Add PrepareContext
Add context methods to ConnPool
Add context methods to Tx
Add context support database/sql
Benchmark - possibly cache done channel on Conn

View File

@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() {
ct.conn.rxReadyForQuery(r)
close(ct.readerErrChan)
return
case commandComplete:
case errorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
default:

View File

@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg {
}
func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) {
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
wbuf := newWriteBuf(f.cn, 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes

View File

@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio
return conn
}
func closeConn(t testing.TB, conn *pgx.Conn) {
err := conn.Close()
if err != nil {

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"io"
"net"
)
// msgReader is a helper that reads values from a PostgreSQL message.
@ -16,11 +17,6 @@ type msgReader struct {
shouldLog func(lvl int) bool
}
// Err returns any error that the msgReader has experienced
func (r *msgReader) Err() error {
return r.err
}
// fatal tells rc that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if r.shouldLog(LogLevelTrace) {
@ -40,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}
_, err := r.reader.Discard(int(r.msgBytesRemaining))
n, err := r.reader.Discard(int(r.msgBytesRemaining))
r.msgBytesRemaining -= int32(n)
if err != nil {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
}
b, err := r.reader.Peek(5)
if err != nil {
r.fatal(err)
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
msgType := b[0]
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
// Try to preload bufio.Reader with entire message
b, err = r.reader.Peek(5 + int(payloadSize))
if err != nil && err != bufio.ErrBufferFull {
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
r.fatal(err)
}
return 0, err
}
r.msgBytesRemaining = payloadSize
r.reader.Discard(5)
return msgType, nil
}

189
msg_reader_test.go Normal file
View File

@ -0,0 +1,189 @@
package pgx
import (
"bufio"
"net"
"testing"
"time"
"github.com/jackc/pgmock/pgmsg"
)
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
t.Parallel()
tests := []struct {
msgType byte
payloadSize int32
buffered bool
}{
{1, 50, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 24000, false},
{9, 4000, true},
{1, 1500, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 14000, false},
{9, 0, true},
{1, 500, true},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for _, tt := range tests {
_, err = conn.Write([]byte{tt.msgType})
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, int(tt.payloadSize))
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
for i, tt := range tests {
msgType, err := mr.rxMsg()
if err != nil {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
if msgType != tt.msgType {
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
}
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
}
}
}
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
testCount := 10000
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for i := 0; i < testCount; i++ {
msgType := byte(i)
_, err = conn.Write([]byte{msgType})
if err != nil {
t.Fatal(err)
}
msgSize := i % 4000
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, msgSize)
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
i := 0
for {
msgType, err := mr.rxMsg()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
continue
} else {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
}
expectedMsgType := byte(i)
if msgType != expectedMsgType {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
}
expectedMsgSize := i % 4000
payload := mr.readBytes(mr.msgBytesRemaining)
if mr.err != nil {
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
}
if len(payload) != expectedMsgSize {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
}
i++
if i == testCount {
break
}
}
}

143
query.go
View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"golang.org/x/net/context"
"time"
)
@ -55,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription {
return rows.fields
}
func (rows *Rows) close() {
// Close closes the rows, making the connection ready for use again. It is safe
// to call Close after rows is already closed.
func (rows *Rows) Close() {
if rows.closed {
return
}
@ -67,6 +70,8 @@ func (rows *Rows) close() {
rows.closed = true
rows.err = rows.conn.termContext(rows.err)
if rows.err == nil {
if rows.conn.shouldLog(LogLevelInfo) {
endTime := time.Now()
@ -81,63 +86,10 @@ func (rows *Rows) close() {
}
}
func (rows *Rows) readUntilReadyForQuery() {
for {
t, r, err := rows.conn.rxMsg()
if err != nil {
rows.close()
return
}
switch t {
case readyForQuery:
rows.conn.rxReadyForQuery(r)
rows.close()
return
case rowDescription:
case dataRow:
case commandComplete:
case bindComplete:
case errorResponse:
err = rows.conn.rxErrorResponse(r)
if rows.err == nil {
rows.err = err
}
default:
err = rows.conn.processContextFreeMsg(t, r)
if err != nil {
rows.close()
return
}
}
}
}
// Close closes the rows, making the connection ready for use again. It is safe
// to call Close after rows is already closed.
func (rows *Rows) Close() {
if rows.closed {
return
}
rows.readUntilReadyForQuery()
rows.close()
}
func (rows *Rows) Err() error {
return rows.err
}
// abort signals that the query was not successfully sent to the server.
// This differs from Fatal in that it is not necessary to readUntilReadyForQuery
func (rows *Rows) abort(err error) {
if rows.err != nil {
return
}
rows.err = err
rows.close()
}
// Fatal signals an error occurred after the query was sent to the server. It
// closes the rows automatically.
func (rows *Rows) Fatal(err error) {
@ -169,10 +121,6 @@ func (rows *Rows) Next() bool {
}
switch t {
case readyForQuery:
rows.conn.rxReadyForQuery(r)
rows.close()
return false
case dataRow:
fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) {
@ -183,7 +131,9 @@ func (rows *Rows) Next() bool {
rows.mr = r
return true
case commandComplete:
case bindComplete:
rows.Close()
return false
default:
err = rows.conn.processContextFreeMsg(t, r)
if err != nil {
@ -441,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
// be returned in an error state. So it is allowed to ignore the error returned
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
c.lastActivityTime = time.Now()
rows := c.getRows(sql, args)
if err := c.lock(); err != nil {
rows.abort(err)
return rows, err
}
rows.unlockConn = true
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.Prepare("", sql)
if err != nil {
rows.abort(err)
return rows, rows.err
}
}
rows.sql = ps.SQL
rows.fields = ps.FieldDescriptions
err := c.sendPreparedQuery(ps, args...)
if err != nil {
rows.abort(err)
}
return rows, rows.err
return c.QueryContext(context.Background(), sql, args...)
}
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
@ -492,3 +417,51 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
rows, _ := c.Query(sql, args...)
return (*Row)(rows)
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
c.lastActivityTime = time.Now()
rows = c.getRows(sql, args)
if err := c.lock(); err != nil {
rows.Fatal(err)
return rows, err
}
rows.unlockConn = true
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.PrepareExContext(ctx, "", sql, nil)
if err != nil {
rows.Fatal(err)
return rows, rows.err
}
}
rows.sql = ps.SQL
rows.fields = ps.FieldDescriptions
err = c.initContext(ctx)
if err != nil {
rows.Fatal(err)
return rows, err
}
err = c.sendPreparedQuery(ps, args...)
if err != nil {
rows.Fatal(err)
err = c.termContext(err)
}
return rows, err
}
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := c.QueryContext(ctx, sql, args...)
return (*Row)(rows)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"database/sql"
"fmt"
"golang.org/x/net/context"
"strings"
"testing"
"time"
@ -1412,3 +1413,165 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 42::integer")
if err != nil {
t.Fatal(err)
}
var result, rowCount int
for rows.Next() {
err = rows.Scan(&result)
if err != nil {
t.Fatal(err)
}
rowCount++
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
if rowCount != 1 {
t.Fatalf("Expected 1 row, got %d", rowCount)
}
if result != 42 {
t.Fatalf("Expected result 42, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
if err != nil {
t.Fatal(err)
}
var result, rowCount int
for rows.Next() {
err = rows.Scan(&result)
if err != nil {
t.Fatal(err)
}
rowCount++
}
if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", rows.Err())
}
if rowCount != 9 {
t.Fatalf("Expected 9 rows, got %d", rowCount)
}
if result != 10 {
t.Fatalf("Expected result 10, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
if err != nil {
t.Fatal(err)
}
for rows.Next() {
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
}
if rows.Err() != context.Canceled {
t.Fatal("Expected context.Canceled error, got %v", rows.Err())
}
ensureConnValid(t, conn)
}
func TestQueryRowContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
if err != nil {
t.Fatal(err)
}
if result != 42 {
t.Fatalf("Expected result 42, got %d", result)
}
ensureConnValid(t, conn)
}
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", err)
}
ensureConnValid(t, conn)
}
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
var result []byte
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
if err != context.Canceled {
t.Fatal("Expected context.Canceled error, got %v", err)
}
ensureConnValid(t, conn)
}

View File

@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *
}
// Wait until there is a byte available before continuing onto the normal msg reading path
_, err = rc.c.reader.Peek(1)
_, err = rc.c.mr.reader.Peek(1)
if err != nil {
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
if err, ok := err.(*net.OpError); ok && err.Timeout() {
@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rows := rc.c.getRows(sql, nil)
if err := rc.c.lock(); err != nil {
rows.abort(err)
rows.Fatal(err)
return rows, err
}
rows.unlockConn = true
err := rc.c.sendSimpleQuery(sql)
if err != nil {
rows.abort(err)
rows.Fatal(err)
}
var t byte
@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
// only Oids. Not much we can do about this.
default:
if e := rc.c.processContextFreeMsg(t, r); e != nil {
rows.abort(e)
rows.Fatal(e)
return rows, e
}
}

View File

@ -44,6 +44,7 @@
package stdlib
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
return c.queryPrepared("", argsV)
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
ps, err := c.conn.PrepareExContext(ctx, "", query, nil)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
@ -226,6 +242,22 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
return &Rows{rows: rows}, nil
}
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryContext(ctx, name, args...)
if err != nil {
fmt.Println(err)
return nil, err
}
return &Rows{rows: rows}, nil
}
// Anything that isn't a database/sql compatible type needs to be forced to
// text format so that pgx.Rows.Values doesn't decode it into a native type
// (e.g. []int32)
@ -318,6 +350,18 @@ func valueToInterface(argsV []driver.Value) []interface{} {
return args
}
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
args := make([]interface{}, 0, len(argsV))
for _, v := range argsV {
if v.Value != nil {
args = append(args, v.Value.(interface{}))
} else {
args = append(args, nil)
}
}
return args
}
type Tx struct {
conn *pgx.Conn
}

View File

@ -3,6 +3,7 @@ package pgx_test
import (
"errors"
"fmt"
"golang.org/x/net/context"
"math/rand"
"testing"
"time"
@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) {
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
{"canceledQueryContext", canceledQueryContext},
{"canceledExecContext", canceledExecContext},
}
var timer *time.Timer
@ -63,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
action := actions[rand.Intn(len(actions))]
err := action.fn(pool, n)
if err != nil {
errChan <- err
errChan <- fmt.Errorf("%s: %v", action.name, err)
break
}
}
@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
return tx.Commit()
}
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
if err == context.Canceled {
return nil
} else if err != nil {
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
}
for rows.Next() {
return errors.New("should never receive row")
}
if rows.Err() != context.Canceled {
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
}
return nil
}
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
if err != context.Canceled {
return fmt.Errorf("Expected context.Canceled error, got %v", err)
}
return nil
}