mirror of https://github.com/jackc/pgx.git
Merge branch 'context' into v3-experimental
commit
7fbff4a495
|
@ -51,6 +51,8 @@ install:
|
|||
- go get -u github.com/shopspring/decimal
|
||||
- go get -u gopkg.in/inconshreveable/log15.v2
|
||||
- go get -u github.com/jackc/fake
|
||||
- go get -u golang.org/x/net/context
|
||||
- go get -u github.com/jackc/pgmock/pgmsg
|
||||
|
||||
script:
|
||||
- go test -v -race -short ./...
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
Extract all locking state into a separate struct that will encapsulate locking and state change behavior.
|
||||
|
||||
This struct should add or subsume at least the following:
|
||||
* alive
|
||||
* closingLock
|
||||
* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one)
|
||||
* busy
|
||||
* lock/unlock
|
||||
* Tx in-progress
|
||||
* Rows in-progress
|
||||
* ConnPool checked-out or checked-in - maybe include reference to conn pool
|
489
conn.go
489
conn.go
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
|
@ -17,9 +18,17 @@ import (
|
|||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
connStatusUninitialized = iota
|
||||
connStatusClosed
|
||||
connStatusIdle
|
||||
connStatusBusy
|
||||
)
|
||||
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
|
@ -39,13 +48,28 @@ type ConnConfig struct {
|
|||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return network, address
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
// Use ConnPool to manage access to multiple database connections from multiple
|
||||
// goroutines.
|
||||
type Conn struct {
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
lastActivityTime time.Time // the last time the connection was used
|
||||
reader *bufio.Reader // buffered reader to improve read performance
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
lastActivityTime time.Time // the last time the connection was used
|
||||
wbuf [1024]byte
|
||||
writeBuf WriteBuf
|
||||
pid int32 // backend pid
|
||||
|
@ -57,17 +81,26 @@ type Conn struct {
|
|||
preparedStatements map[string]*PreparedStatement
|
||||
channels map[string]struct{}
|
||||
notifications []*Notification
|
||||
alive bool
|
||||
causeOfDeath error
|
||||
logger Logger
|
||||
logLevel int
|
||||
mr msgReader
|
||||
fp *fastpath
|
||||
pgsqlAfInet *byte
|
||||
pgsqlAfInet6 *byte
|
||||
busy bool
|
||||
poolResetCount int
|
||||
preallocatedRows []Rows
|
||||
|
||||
status int32 // One of connStatus* constants
|
||||
causeOfDeath error
|
||||
|
||||
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
||||
cancelQueryInProgress int32
|
||||
cancelQueryCompleted chan struct{}
|
||||
|
||||
// context support
|
||||
ctxInProgress bool
|
||||
doneChan chan struct{}
|
||||
closedChan chan error
|
||||
}
|
||||
|
||||
// PreparedStatement is a description of a prepared statement
|
||||
|
@ -194,17 +227,7 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql
|
|||
}
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(c.config.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = c.config.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
}
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
|
@ -238,15 +261,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
defer func() {
|
||||
if c != nil && err != nil {
|
||||
c.conn.Close()
|
||||
c.alive = false
|
||||
atomic.StoreInt32(&c.status, connStatusClosed)
|
||||
}
|
||||
}()
|
||||
|
||||
c.RuntimeParams = make(map[string]string)
|
||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||
c.channels = make(map[string]struct{})
|
||||
c.alive = true
|
||||
atomic.StoreInt32(&c.status, connStatusIdle)
|
||||
c.lastActivityTime = time.Now()
|
||||
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||
c.doneChan = make(chan struct{})
|
||||
c.closedChan = make(chan error)
|
||||
|
||||
if tlsConfig != nil {
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
|
@ -257,8 +283,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
}
|
||||
|
||||
c.reader = bufio.NewReader(c.conn)
|
||||
c.mr.reader = c.reader
|
||||
c.mr.reader = bufio.NewReader(c.conn)
|
||||
|
||||
msg := newStartupMessage()
|
||||
|
||||
|
@ -389,14 +414,17 @@ func (c *Conn) PID() int32 {
|
|||
// Close closes a connection. It is safe to call Close on a already closed
|
||||
// connection.
|
||||
func (c *Conn) Close() (err error) {
|
||||
if !c.IsAlive() {
|
||||
return nil
|
||||
for {
|
||||
status := atomic.LoadInt32(&c.status)
|
||||
if status < connStatusIdle {
|
||||
return nil
|
||||
}
|
||||
if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
wbuf := newWriteBuf(c, 'X')
|
||||
wbuf.closeMsg()
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||
|
||||
c.die(errors.New("Closed"))
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
|
@ -614,12 +642,36 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
return c.PrepareExContext(context.Background(), name, sql, opts)
|
||||
}
|
||||
|
||||
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps, err = c.prepareEx(name, sql, opts)
|
||||
err = c.termContext(err)
|
||||
return ps, err
|
||||
}
|
||||
|
||||
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
if name != "" {
|
||||
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
||||
return ps, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelError) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -659,6 +711,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
c.readyForQuery = false
|
||||
|
||||
ps = &PreparedStatement{Name: name, SQL: sql}
|
||||
|
||||
|
@ -673,7 +726,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
}
|
||||
|
||||
switch t {
|
||||
case parseComplete:
|
||||
case parameterDescription:
|
||||
ps.ParameterOIDs = c.rxParameterDescription(r)
|
||||
|
||||
|
@ -687,7 +739,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
ps.FieldDescriptions[i].DataTypeName = t.Name
|
||||
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
|
||||
}
|
||||
case noData:
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
|
||||
|
@ -705,7 +756,29 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
}
|
||||
|
||||
// Deallocate released a prepared statement
|
||||
func (c *Conn) Deallocate(name string) (err error) {
|
||||
func (c *Conn) Deallocate(name string) error {
|
||||
return c.deallocateContext(context.Background(), name)
|
||||
}
|
||||
|
||||
// TODO - consider making this public
|
||||
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(c.preparedStatements, name)
|
||||
|
||||
// close
|
||||
|
@ -776,6 +849,17 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||
return notification, nil
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
|
||||
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
|
||||
cancelFn()
|
||||
return nil, err
|
||||
}
|
||||
cancelFn()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stopTime := time.Now().Add(timeout)
|
||||
|
||||
for {
|
||||
|
@ -835,7 +919,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
|
|||
}
|
||||
|
||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||
_, err = c.reader.Peek(1)
|
||||
_, err = c.mr.reader.Peek(1)
|
||||
if err != nil {
|
||||
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||
|
@ -868,7 +952,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
|
|||
}
|
||||
|
||||
func (c *Conn) IsAlive() bool {
|
||||
return c.alive
|
||||
return atomic.LoadInt32(&c.status) >= connStatusIdle
|
||||
}
|
||||
|
||||
func (c *Conn) CauseOfDeath() error {
|
||||
|
@ -883,6 +967,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||
}
|
||||
|
||||
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
wbuf := newWriteBuf(c, 'Q')
|
||||
|
@ -894,6 +981,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
|||
c.die(err)
|
||||
return err
|
||||
}
|
||||
c.readyForQuery = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -911,6 +999,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments))
|
||||
}
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// bind
|
||||
wbuf := newWriteBuf(c, 'B')
|
||||
wbuf.WriteByte(0)
|
||||
|
@ -958,6 +1050,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
if err != nil {
|
||||
c.die(err)
|
||||
}
|
||||
c.readyForQuery = false
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -965,91 +1058,52 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
if err = c.lock(); err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
c.lastActivityTime = startTime
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||
}
|
||||
} else {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var softErr error
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
return commandTag, softErr
|
||||
case rowDescription:
|
||||
case dataRow:
|
||||
case bindComplete:
|
||||
case commandComplete:
|
||||
commandTag = CommandTag(r.readCString())
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.ExecContext(context.Background(), sql, arguments...)
|
||||
}
|
||||
|
||||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
// authentication or query response. The response to these messages is the same
|
||||
// regardless of when they occur. It also ignores messages that are only
|
||||
// meaningful in a given context. These messages can occur due to a context
|
||||
// deadline interrupting message processing. For example, an interrupted query
|
||||
// may have left DataRow messages on the wire.
|
||||
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
|
||||
switch t {
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
return nil
|
||||
case bindComplete:
|
||||
case commandComplete:
|
||||
case dataRow:
|
||||
case emptyQueryResponse:
|
||||
case errorResponse:
|
||||
return c.rxErrorResponse(r)
|
||||
case noData:
|
||||
case noticeResponse:
|
||||
return nil
|
||||
case emptyQueryResponse:
|
||||
return nil
|
||||
case notificationResponse:
|
||||
c.rxNotificationResponse(r)
|
||||
return nil
|
||||
case parameterDescription:
|
||||
case parseComplete:
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
case rowDescription:
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Received unknown message type: %c", t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
||||
if !c.alive {
|
||||
if atomic.LoadInt32(&c.status) < connStatusIdle {
|
||||
return 0, nil, ErrDeadConn
|
||||
}
|
||||
|
||||
t, err = c.mr.rxMsg()
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
c.die(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.lastActivityTime = time.Now()
|
||||
|
@ -1150,6 +1204,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) {
|
|||
}
|
||||
|
||||
func (c *Conn) rxReadyForQuery(r *msgReader) {
|
||||
c.readyForQuery = true
|
||||
c.txStatus = r.readByte()
|
||||
}
|
||||
|
||||
|
@ -1230,25 +1285,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
|
|||
}
|
||||
|
||||
func (c *Conn) die(err error) {
|
||||
c.alive = false
|
||||
atomic.StoreInt32(&c.status, connStatusClosed)
|
||||
c.causeOfDeath = err
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) lock() error {
|
||||
if c.busy {
|
||||
return ErrConnBusy
|
||||
if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) {
|
||||
return nil
|
||||
}
|
||||
c.busy = true
|
||||
return nil
|
||||
return ErrConnBusy
|
||||
}
|
||||
|
||||
func (c *Conn) unlock() error {
|
||||
if !c.busy {
|
||||
return errors.New("unlock conn that is not busy")
|
||||
if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) {
|
||||
return nil
|
||||
}
|
||||
c.busy = false
|
||||
return nil
|
||||
return errors.New("unlock conn that is not busy")
|
||||
}
|
||||
|
||||
func (c *Conn) shouldLog(lvl int) bool {
|
||||
|
@ -1286,3 +1339,229 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) {
|
|||
func quoteIdentifier(s string) string {
|
||||
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
|
||||
}
|
||||
|
||||
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
|
||||
// error if unable to deliver the cancel request, but lack of an error does not
|
||||
// ensure that the query was canceled. As specified in the documentation, there
|
||||
// is no way to be sure a query was canceled. See
|
||||
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
||||
func (c *Conn) cancelQuery() {
|
||||
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
|
||||
panic("cancelQuery when cancelQueryInProgress")
|
||||
}
|
||||
|
||||
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
||||
c.Close() // Close connection if unable to set deadline
|
||||
return
|
||||
}
|
||||
|
||||
doCancel := func() error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancelConn.Close()
|
||||
|
||||
// If server doesn't process cancellation request in bounded time then abort.
|
||||
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 16)
|
||||
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cancelConn.Read(buf)
|
||||
if err != io.EOF {
|
||||
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := doCancel()
|
||||
if err != nil {
|
||||
c.Close() // Something is very wrong. Terminate the connection.
|
||||
}
|
||||
c.cancelQueryCompleted <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Conn) Ping() error {
|
||||
return c.PingContext(context.Background())
|
||||
}
|
||||
|
||||
func (c *Conn) PingContext(ctx context.Context) error {
|
||||
_, err := c.ExecContext(ctx, ";")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.lock(); err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
c.lastActivityTime = startTime
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||
}
|
||||
} else {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var softErr error
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *msgReader
|
||||
t, r, err = c.rxMsg()
|
||||
if err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
return commandTag, softErr
|
||||
case commandComplete:
|
||||
commandTag = CommandTag(r.readCString())
|
||||
default:
|
||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||
softErr = e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) initContext(ctx context.Context) error {
|
||||
if c.ctxInProgress {
|
||||
return errors.New("ctx already in progress")
|
||||
}
|
||||
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
c.ctxInProgress = true
|
||||
|
||||
go c.contextHandler(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) termContext(opErr error) error {
|
||||
if !c.ctxInProgress {
|
||||
return opErr
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
select {
|
||||
case err = <-c.closedChan:
|
||||
if opErr == nil {
|
||||
err = nil
|
||||
}
|
||||
case c.doneChan <- struct{}{}:
|
||||
err = opErr
|
||||
}
|
||||
|
||||
c.ctxInProgress = false
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) contextHandler(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancelQuery()
|
||||
c.closedChan <- ctx.Err()
|
||||
case <-c.doneChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
||||
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.cancelQueryCompleted:
|
||||
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
|
||||
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
||||
c.Close() // Close connection if unable to disable deadline
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||
for !c.readyForQuery {
|
||||
t, r, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case errorResponse:
|
||||
pgErr := c.rxErrorResponse(r)
|
||||
if pgErr.Severity == "FATAL" {
|
||||
return pgErr
|
||||
}
|
||||
default:
|
||||
err = c.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
38
conn_pool.go
38
conn_pool.go
|
@ -2,6 +2,7 @@ package pgx
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"golang.org/x/net/context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -181,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||
|
||||
// Release gives up use of a connection.
|
||||
func (p *ConnPool) Release(conn *Conn) {
|
||||
if conn.ctxInProgress {
|
||||
panic("should never release when context is in progress")
|
||||
}
|
||||
|
||||
if conn.txStatus != 'I' {
|
||||
conn.Exec("rollback")
|
||||
}
|
||||
|
@ -357,6 +362,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
|
|||
return c.Exec(sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.ExecContext(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
// Query acquires a connection and delegates the call to that connection. When
|
||||
// *Rows are closed, the connection is released automatically.
|
||||
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
|
@ -377,6 +392,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
p.Release(c)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.AfterClose(p.rowsAfterClose)
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRow acquires a connection and delegates the call to that connection. The
|
||||
// connection is released automatically after Scan is called on the returned
|
||||
// *Row.
|
||||
|
@ -385,6 +418,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
|
|||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||
rows, _ := p.QueryContext(ctx, sql, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
// Begin acquires a connection and begins a transaction on it. When the
|
||||
// transaction is closed the connection will be automatically released.
|
||||
func (p *ConnPool) Begin() (*Tx, error) {
|
||||
|
|
59
conn_test.go
59
conn_test.go
|
@ -3,6 +3,7 @@ package pgx_test
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -816,6 +817,64 @@ func TestExecFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
|
||||
rows, _ := conn.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
|
||||
if err != context.Canceled {
|
||||
t.Fatal("Expected context.Canceled err, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
Add more testing
|
||||
- stress test style
|
||||
- pgmock
|
||||
|
||||
Add documentation
|
||||
|
||||
Add PrepareContext
|
||||
Add context methods to ConnPool
|
||||
Add context methods to Tx
|
||||
Add context support database/sql
|
||||
|
||||
Benchmark - possibly cache done channel on Conn
|
|
@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() {
|
|||
ct.conn.rxReadyForQuery(r)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case commandComplete:
|
||||
case errorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||
default:
|
||||
|
|
|
@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg {
|
|||
}
|
||||
|
||||
func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) {
|
||||
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||
wbuf.WriteInt32(int32(oid)) // function object id
|
||||
wbuf.WriteInt16(1) // # of argument format codes
|
||||
|
|
|
@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio
|
|||
return conn
|
||||
}
|
||||
|
||||
|
||||
func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
|
@ -16,11 +17,6 @@ type msgReader struct {
|
|||
shouldLog func(lvl int) bool
|
||||
}
|
||||
|
||||
// Err returns any error that the msgReader has experienced
|
||||
func (r *msgReader) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// fatal tells rc that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
if r.shouldLog(LogLevelTrace) {
|
||||
|
@ -40,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
_, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
n, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
r.msgBytesRemaining -= int32(n)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(5)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
msgType := b[0]
|
||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
|
||||
// Try to preload bufio.Reader with entire message
|
||||
b, err = r.reader.Peek(5 + int(payloadSize))
|
||||
if err != nil && err != bufio.ErrBufferFull {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.msgBytesRemaining = payloadSize
|
||||
r.reader.Discard(5)
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgmock/pgmsg"
|
||||
)
|
||||
|
||||
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
msgType byte
|
||||
payloadSize int32
|
||||
buffered bool
|
||||
}{
|
||||
{1, 50, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 24000, false},
|
||||
{9, 4000, true},
|
||||
{1, 1500, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 14000, false},
|
||||
{9, 0, true},
|
||||
{1, 500, true},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
_, err = conn.Write([]byte{tt.msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, int(tt.payloadSize))
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if msgType != tt.msgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
|
||||
}
|
||||
|
||||
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
|
||||
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
testCount := 10000
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < testCount; i++ {
|
||||
msgType := byte(i)
|
||||
|
||||
_, err = conn.Write([]byte{msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msgSize := i % 4000
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, msgSize)
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
|
||||
i := 0
|
||||
for {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
continue
|
||||
} else {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
expectedMsgType := byte(i)
|
||||
if msgType != expectedMsgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
|
||||
}
|
||||
|
||||
expectedMsgSize := i % 4000
|
||||
payload := mr.readBytes(mr.msgBytesRemaining)
|
||||
if mr.err != nil {
|
||||
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
|
||||
}
|
||||
if len(payload) != expectedMsgSize {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
|
||||
}
|
||||
|
||||
i++
|
||||
if i == testCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
143
query.go
143
query.go
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -55,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription {
|
|||
return rows.fields
|
||||
}
|
||||
|
||||
func (rows *Rows) close() {
|
||||
// Close closes the rows, making the connection ready for use again. It is safe
|
||||
// to call Close after rows is already closed.
|
||||
func (rows *Rows) Close() {
|
||||
if rows.closed {
|
||||
return
|
||||
}
|
||||
|
@ -67,6 +70,8 @@ func (rows *Rows) close() {
|
|||
|
||||
rows.closed = true
|
||||
|
||||
rows.err = rows.conn.termContext(rows.err)
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.conn.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
|
@ -81,63 +86,10 @@ func (rows *Rows) close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (rows *Rows) readUntilReadyForQuery() {
|
||||
for {
|
||||
t, r, err := rows.conn.rxMsg()
|
||||
if err != nil {
|
||||
rows.close()
|
||||
return
|
||||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
rows.conn.rxReadyForQuery(r)
|
||||
rows.close()
|
||||
return
|
||||
case rowDescription:
|
||||
case dataRow:
|
||||
case commandComplete:
|
||||
case bindComplete:
|
||||
case errorResponse:
|
||||
err = rows.conn.rxErrorResponse(r)
|
||||
if rows.err == nil {
|
||||
rows.err = err
|
||||
}
|
||||
default:
|
||||
err = rows.conn.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
rows.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the rows, making the connection ready for use again. It is safe
|
||||
// to call Close after rows is already closed.
|
||||
func (rows *Rows) Close() {
|
||||
if rows.closed {
|
||||
return
|
||||
}
|
||||
rows.readUntilReadyForQuery()
|
||||
rows.close()
|
||||
}
|
||||
|
||||
func (rows *Rows) Err() error {
|
||||
return rows.err
|
||||
}
|
||||
|
||||
// abort signals that the query was not successfully sent to the server.
|
||||
// This differs from Fatal in that it is not necessary to readUntilReadyForQuery
|
||||
func (rows *Rows) abort(err error) {
|
||||
if rows.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rows.err = err
|
||||
rows.close()
|
||||
}
|
||||
|
||||
// Fatal signals an error occurred after the query was sent to the server. It
|
||||
// closes the rows automatically.
|
||||
func (rows *Rows) Fatal(err error) {
|
||||
|
@ -169,10 +121,6 @@ func (rows *Rows) Next() bool {
|
|||
}
|
||||
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
rows.conn.rxReadyForQuery(r)
|
||||
rows.close()
|
||||
return false
|
||||
case dataRow:
|
||||
fieldCount := r.readInt16()
|
||||
if int(fieldCount) != len(rows.fields) {
|
||||
|
@ -183,7 +131,9 @@ func (rows *Rows) Next() bool {
|
|||
rows.mr = r
|
||||
return true
|
||||
case commandComplete:
|
||||
case bindComplete:
|
||||
rows.Close()
|
||||
return false
|
||||
|
||||
default:
|
||||
err = rows.conn.processContextFreeMsg(t, r)
|
||||
if err != nil {
|
||||
|
@ -441,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
|
|||
// be returned in an error state. So it is allowed to ignore the error returned
|
||||
// from Query and handle it in *Rows.
|
||||
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
c.lastActivityTime = time.Now()
|
||||
|
||||
rows := c.getRows(sql, args)
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
rows.abort(err)
|
||||
return rows, err
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.Prepare("", sql)
|
||||
if err != nil {
|
||||
rows.abort(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
rows.fields = ps.FieldDescriptions
|
||||
err := c.sendPreparedQuery(ps, args...)
|
||||
if err != nil {
|
||||
rows.abort(err)
|
||||
}
|
||||
return rows, rows.err
|
||||
return c.QueryContext(context.Background(), sql, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
|
||||
|
@ -492,3 +417,51 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
|||
rows, _ := c.Query(sql, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.lastActivityTime = time.Now()
|
||||
|
||||
rows = c.getRows(sql, args)
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.PrepareExContext(ctx, "", sql, nil)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, rows.err
|
||||
}
|
||||
}
|
||||
rows.sql = ps.SQL
|
||||
rows.fields = ps.FieldDescriptions
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
err = c.sendPreparedQuery(ps, args...)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
err = c.termContext(err)
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||
rows, _ := c.QueryContext(ctx, sql, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
|
163
query_test.go
163
query_test.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -1412,3 +1413,165 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
|||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select 42::integer")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result, rowCount int
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(rows.Err())
|
||||
}
|
||||
|
||||
if rowCount != 1 {
|
||||
t.Fatalf("Expected 1 row, got %d", rowCount)
|
||||
}
|
||||
if result != 42 {
|
||||
t.Fatalf("Expected result 42, got %d", result)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result, rowCount int
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" {
|
||||
t.Fatalf("Expected division by zero error, but got %v", rows.Err())
|
||||
}
|
||||
|
||||
if rowCount != 9 {
|
||||
t.Fatalf("Expected 9 rows, got %d", rowCount)
|
||||
}
|
||||
if result != 10 {
|
||||
t.Fatalf("Expected result 10, got %d", result)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
|
||||
}
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
t.Fatal("Expected context.Canceled error, got %v", rows.Err())
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
var result int
|
||||
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != 42 {
|
||||
t.Fatalf("Expected result 42, got %d", result)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
var result int
|
||||
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
|
||||
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
|
||||
t.Fatalf("Expected division by zero error, but got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
var result []byte
|
||||
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
|
||||
if err != context.Canceled {
|
||||
t.Fatal("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
|
@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *
|
|||
}
|
||||
|
||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||
_, err = rc.c.reader.Peek(1)
|
||||
_, err = rc.c.mr.reader.Peek(1)
|
||||
if err != nil {
|
||||
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||
|
@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
|||
rows := rc.c.getRows(sql, nil)
|
||||
|
||||
if err := rc.c.lock(); err != nil {
|
||||
rows.abort(err)
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
err := rc.c.sendSimpleQuery(sql)
|
||||
if err != nil {
|
||||
rows.abort(err)
|
||||
rows.Fatal(err)
|
||||
}
|
||||
|
||||
var t byte
|
||||
|
@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
|||
// only Oids. Not much we can do about this.
|
||||
default:
|
||||
if e := rc.c.processContextFreeMsg(t, r); e != nil {
|
||||
rows.abort(e)
|
||||
rows.Fatal(e)
|
||||
return rows, e
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
package stdlib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
|
@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||
return c.queryPrepared("", argsV)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.PrepareExContext(ctx, "", query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
}
|
||||
|
||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
|
@ -226,6 +242,22 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
|
|||
return &Rows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, name, args...)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Rows{rows: rows}, nil
|
||||
}
|
||||
|
||||
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||
// (e.g. []int32)
|
||||
|
@ -318,6 +350,18 @@ func valueToInterface(argsV []driver.Value) []interface{} {
|
|||
return args
|
||||
}
|
||||
|
||||
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
||||
args := make([]interface{}, 0, len(argsV))
|
||||
for _, v := range argsV {
|
||||
if v.Value != nil {
|
||||
args = append(args, v.Value.(interface{}))
|
||||
} else {
|
||||
args = append(args, nil)
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
type Tx struct {
|
||||
conn *pgx.Conn
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package pgx_test
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) {
|
|||
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
||||
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
||||
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
||||
{"canceledQueryContext", canceledQueryContext},
|
||||
{"canceledExecContext", canceledExecContext},
|
||||
}
|
||||
|
||||
var timer *time.Timer
|
||||
|
@ -63,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
|
|||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pool, n)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
errChan <- fmt.Errorf("%s: %v", action.name, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
|
|||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
return errors.New("should never receive row")
|
||||
}
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
|
||||
if err != context.Canceled {
|
||||
return fmt.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue