diff --git a/.travis.yml b/.travis.yml index d9ea43b0..9ae8d963 100644 --- a/.travis.yml +++ b/.travis.yml @@ -51,6 +51,8 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake + - go get -u golang.org/x/net/context + - go get -u github.com/jackc/pgmock/pgmsg script: - go test -v -race -short ./... diff --git a/conn-lock-todo.txt b/conn-lock-todo.txt new file mode 100644 index 00000000..ab5eac95 --- /dev/null +++ b/conn-lock-todo.txt @@ -0,0 +1,11 @@ +Extract all locking state into a separate struct that will encapsulate locking and state change behavior. + +This struct should add or subsume at least the following: +* alive +* closingLock +* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one) +* busy +* lock/unlock +* Tx in-progress +* Rows in-progress +* ConnPool checked-out or checked-in - maybe include reference to conn pool diff --git a/conn.go b/conn.go index 75792408..d1205636 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "errors" "fmt" + "golang.org/x/net/context" "io" "net" "net/url" @@ -17,9 +18,17 @@ import ( "regexp" "strconv" "strings" + "sync/atomic" "time" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -39,13 +48,28 @@ type ConnConfig struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) } +func (cc *ConnConfig) networkAddress() (network, address string) { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } + + return network, address +} + // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - lastActivityTime time.Time // the last time the connection was used - reader *bufio.Reader // buffered reader to improve read performance + conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf pid int32 // backend pid @@ -57,17 +81,26 @@ type Conn struct { preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification - alive bool - causeOfDeath error logger Logger logLevel int mr msgReader fp *fastpath pgsqlAfInet *byte pgsqlAfInet6 *byte - busy bool poolResetCount int preallocatedRows []Rows + + status int32 // One of connStatus* constants + causeOfDeath error + + readyForQuery bool // connection has received ReadyForQuery message since last query was sent + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} + + // context support + ctxInProgress bool + doneChan chan struct{} + closedChan chan error } // PreparedStatement is a description of a prepared statement @@ -194,17 +227,7 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql } } - network := "tcp" - address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(c.config.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = c.config.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - } + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } @@ -238,15 +261,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - c.alive = true + atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() + c.cancelQueryCompleted = make(chan struct{}, 1) + c.doneChan = make(chan struct{}) + c.closedChan = make(chan error) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -257,8 +283,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.reader = bufio.NewReader(c.conn) - c.mr.reader = c.reader + c.mr.reader = bufio.NewReader(c.conn) msg := newStartupMessage() @@ -389,14 +414,17 @@ func (c *Conn) PID() int32 { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - if !c.IsAlive() { - return nil + for { + status := atomic.LoadInt32(&c.status) + if status < connStatusIdle { + return nil + } + if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { + break + } } - wbuf := newWriteBuf(c, 'X') - wbuf.closeMsg() - - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) c.die(errors.New("Closed")) if c.shouldLog(LogLevelInfo) { @@ -614,12 +642,36 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + return c.PrepareExContext(context.Background(), name, sql, opts) +} + +func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + + err = c.initContext(ctx) + if err != nil { + return nil, err + } + + ps, err = c.prepareEx(name, sql, opts) + err = c.termContext(err) + return ps, err +} + +func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil } } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + if c.shouldLog(LogLevelError) { defer func() { if err != nil { @@ -659,6 +711,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared c.die(err) return nil, err } + c.readyForQuery = false ps = &PreparedStatement{Name: name, SQL: sql} @@ -673,7 +726,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } switch t { - case parseComplete: case parameterDescription: ps.ParameterOIDs = c.rxParameterDescription(r) @@ -687,7 +739,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].DataTypeName = t.Name ps.FieldDescriptions[i].FormatCode = t.DefaultFormat } - case noData: case readyForQuery: c.rxReadyForQuery(r) @@ -705,7 +756,29 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) (err error) { +func (c *Conn) Deallocate(name string) error { + return c.deallocateContext(context.Background(), name) +} + +// TODO - consider making this public +func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + err = c.initContext(ctx) + if err != nil { + return err + } + defer func() { + err = c.termContext(err) + }() + + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + delete(c.preparedStatements, name) // close @@ -776,6 +849,17 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + ctx, cancelFn := context.WithTimeout(context.Background(), timeout) + if err := c.waitForPreviousCancelQuery(ctx); err != nil { + cancelFn() + return nil, err + } + cancelFn() + + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + stopTime := time.Now().Add(timeout) for { @@ -835,7 +919,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.reader.Peek(1) + _, err = c.mr.reader.Peek(1) if err != nil { c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { @@ -868,7 +952,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - return c.alive + return atomic.LoadInt32(&c.status) >= connStatusIdle } func (c *Conn) CauseOfDeath() error { @@ -883,6 +967,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { } func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } if len(args) == 0 { wbuf := newWriteBuf(c, 'Q') @@ -894,6 +981,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } + c.readyForQuery = false return nil } @@ -911,6 +999,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + // bind wbuf := newWriteBuf(c, 'B') wbuf.WriteByte(0) @@ -958,6 +1050,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} if err != nil { c.die(err) } + c.readyForQuery = false return err } @@ -965,91 +1058,52 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - if err = c.lock(); err != nil { - return commandTag, err - } - - startTime := time.Now() - c.lastActivityTime = startTime - - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) - } - } - - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } - }() - - if err = c.sendQuery(sql, arguments...); err != nil { - return - } - - var softErr error - - for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() - if err != nil { - return commandTag, err - } - - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return commandTag, softErr - case rowDescription: - case dataRow: - case bindComplete: - case commandComplete: - commandTag = CommandTag(r.readCString()) - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } + return c.ExecContext(context.Background(), sql, arguments...) } // Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages -// is the same regardless of when they occur. +// authentication or query response. The response to these messages is the same +// regardless of when they occur. It also ignores messages that are only +// meaningful in a given context. These messages can occur due to a context +// deadline interrupting message processing. For example, an interrupted query +// may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { switch t { - case 'S': - c.rxParameterStatus(r) - return nil + case bindComplete: + case commandComplete: + case dataRow: + case emptyQueryResponse: case errorResponse: return c.rxErrorResponse(r) + case noData: case noticeResponse: - return nil - case emptyQueryResponse: - return nil case notificationResponse: c.rxNotificationResponse(r) - return nil + case parameterDescription: + case parseComplete: + case readyForQuery: + c.rxReadyForQuery(r) + case rowDescription: + case 'S': + c.rxParameterStatus(r) + default: return fmt.Errorf("Received unknown message type: %c", t) } + + return nil } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { - if !c.alive { + if atomic.LoadInt32(&c.status) < connStatusIdle { return 0, nil, ErrDeadConn } t, err = c.mr.rxMsg() if err != nil { - c.die(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + c.die(err) + } } c.lastActivityTime = time.Now() @@ -1150,6 +1204,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) { } func (c *Conn) rxReadyForQuery(r *msgReader) { + c.readyForQuery = true c.txStatus = r.readByte() } @@ -1230,25 +1285,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if c.busy { - return ErrConnBusy + if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { + return nil } - c.busy = true - return nil + return ErrConnBusy } func (c *Conn) unlock() error { - if !c.busy { - return errors.New("unlock conn that is not busy") + if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { + return nil } - c.busy = false - return nil + return errors.New("unlock conn that is not busy") } func (c *Conn) shouldLog(lvl int) bool { @@ -1286,3 +1339,229 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } + +// cancelQuery sends a cancel request to the PostgreSQL server. It returns an +// error if unable to deliver the cancel request, but lack of an error does not +// ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See +// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 +func (c *Conn) cancelQuery() { + if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { + panic("cancelQuery when cancelQueryInProgress") + } + + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + return + } + + doCancel := func() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. + err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil + } + + go func() { + err := doCancel() + if err != nil { + c.Close() // Something is very wrong. Terminate the connection. + } + c.cancelQueryCompleted <- struct{}{} + }() +} + +func (c *Conn) Ping() error { + return c.PingContext(context.Background()) +} + +func (c *Conn) PingContext(ctx context.Context) error { + _, err := c.ExecContext(ctx, ";") + return err +} + +func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return "", err + } + + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return commandTag, err + } + + startTime := time.Now() + c.lastActivityTime = startTime + + defer func() { + if err == nil { + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + } + } else { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + } + } + + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() + + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + + var softErr error + + for { + var t byte + var r *msgReader + t, r, err = c.rxMsg() + if err != nil { + return commandTag, err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return commandTag, softErr + case commandComplete: + commandTag = CommandTag(r.readCString()) + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } + } + } + + return commandTag, err +} + +func (c *Conn) initContext(ctx context.Context) error { + if c.ctxInProgress { + return errors.New("ctx already in progress") + } + + if ctx.Done() == nil { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + c.ctxInProgress = true + + go c.contextHandler(ctx) + + return nil +} + +func (c *Conn) termContext(opErr error) error { + if !c.ctxInProgress { + return opErr + } + + var err error + + select { + case err = <-c.closedChan: + if opErr == nil { + err = nil + } + case c.doneChan <- struct{}{}: + err = opErr + } + + c.ctxInProgress = false + return err +} + +func (c *Conn) contextHandler(ctx context.Context) { + select { + case <-ctx.Done(): + c.cancelQuery() + c.closedChan <- ctx.Err() + case <-c.doneChan: + } +} + +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + return nil + } + + select { + case <-c.cancelQueryCompleted: + atomic.StoreInt32(&c.cancelQueryInProgress, 0) + if err := c.conn.SetDeadline(time.Time{}); err != nil { + c.Close() // Close connection if unable to disable deadline + return err + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *Conn) ensureConnectionReadyForQuery() error { + for !c.readyForQuery { + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case errorResponse: + pgErr := c.rxErrorResponse(r) + if pgErr.Severity == "FATAL" { + return pgErr + } + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/conn_pool.go b/conn_pool.go index 6614c4f0..9701f170 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,6 +2,7 @@ package pgx import ( "errors" + "golang.org/x/net/context" "sync" "time" ) @@ -181,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. func (p *ConnPool) Release(conn *Conn) { + if conn.ctxInProgress { + panic("should never release when context is in progress") + } + if conn.txStatus != 'I' { conn.Exec("rollback") } @@ -357,6 +362,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.ExecContext(ctx, sql, arguments...) +} + // Query acquires a connection and delegates the call to that connection. When // *Rows are closed, the connection is released automatically. func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { @@ -377,6 +392,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, nil } +func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.AfterClose(p.rowsAfterClose) + + return rows, nil +} + // QueryRow acquires a connection and delegates the call to that connection. The // connection is released automatically after Scan is called on the returned // *Row. @@ -385,6 +418,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := p.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} + // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { diff --git a/conn_test.go b/conn_test.go index ecd7e88d..a8398507 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "crypto/tls" "fmt" + "golang.org/x/net/context" "net" "os" "reflect" @@ -816,6 +817,64 @@ func TestExecFailure(t *testing.T) { } } +func TestExecContextWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecContext: %v", commandTag) + } +} + +func TestExecContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecContext(ctx, "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestExecContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + _, err := conn.ExecContext(ctx, "select pg_sleep(60)") + if err != context.Canceled { + t.Fatal("Expected context.Canceled err, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/context-todo.txt b/context-todo.txt new file mode 100644 index 00000000..b5a20d0a --- /dev/null +++ b/context-todo.txt @@ -0,0 +1,12 @@ +Add more testing +- stress test style +- pgmock + +Add documentation + +Add PrepareContext +Add context methods to ConnPool +Add context methods to Tx +Add context support database/sql + +Benchmark - possibly cache done channel on Conn diff --git a/copy_to.go b/copy_to.go index 91292bb0..dd70ada3 100644 --- a/copy_to.go +++ b/copy_to.go @@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() { ct.conn.rxReadyForQuery(r) close(ct.readerErrChan) return - case commandComplete: case errorResponse: ct.readerErrChan <- ct.conn.rxErrorResponse(r) default: diff --git a/fastpath.go b/fastpath.go index 28f88d5e..af055e56 100644 --- a/fastpath.go +++ b/fastpath.go @@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg { } func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) { + if err := f.cn.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + wbuf := newWriteBuf(f.cn, 'F') // function call wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt16(1) // # of argument format codes diff --git a/helper_test.go b/helper_test.go index eff731e8..21f86de5 100644 --- a/helper_test.go +++ b/helper_test.go @@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio return conn } - func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close() if err != nil { diff --git a/msg_reader.go b/msg_reader.go index 0c3c23b8..f507c198 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "io" + "net" ) // msgReader is a helper that reads values from a PostgreSQL message. @@ -16,11 +17,6 @@ type msgReader struct { shouldLog func(lvl int) bool } -// Err returns any error that the msgReader has experienced -func (r *msgReader) Err() error { - return r.err -} - // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { @@ -40,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) { r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } - _, err := r.reader.Discard(int(r.msgBytesRemaining)) + n, err := r.reader.Discard(int(r.msgBytesRemaining)) + r.msgBytesRemaining -= int32(n) if err != nil { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } } b, err := r.reader.Peek(5) if err != nil { - r.fatal(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } + msgType := b[0] - r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 + payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + + // Try to preload bufio.Reader with entire message + b, err = r.reader.Peek(5 + int(payloadSize)) + if err != nil && err != bufio.ErrBufferFull { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } + return 0, err + } + + r.msgBytesRemaining = payloadSize r.reader.Discard(5) + return msgType, nil } diff --git a/msg_reader_test.go b/msg_reader_test.go new file mode 100644 index 00000000..2bbd53c9 --- /dev/null +++ b/msg_reader_test.go @@ -0,0 +1,189 @@ +package pgx + +import ( + "bufio" + "net" + "testing" + "time" + + "github.com/jackc/pgmock/pgmsg" +) + +func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { + t.Parallel() + + tests := []struct { + msgType byte + payloadSize int32 + buffered bool + }{ + {1, 50, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 24000, false}, + {9, 4000, true}, + {1, 1500, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 14000, false}, + {9, 0, true}, + {1, 500, true}, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for _, tt := range tests { + _, err = conn.Write([]byte{tt.msgType}) + if err != nil { + t.Fatal(err) + } + + _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, int(tt.payloadSize)) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + for i, tt := range tests { + msgType, err := mr.rxMsg() + if err != nil { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + + if msgType != tt.msgType { + t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) + } + + if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { + t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) + } + } +} + +func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + testCount := 10000 + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for i := 0; i < testCount; i++ { + msgType := byte(i) + + _, err = conn.Write([]byte{msgType}) + if err != nil { + t.Fatal(err) + } + + msgSize := i % 4000 + + _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, msgSize) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + + i := 0 + for { + msgType, err := mr.rxMsg() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + continue + } else { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + } + + expectedMsgType := byte(i) + if msgType != expectedMsgType { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) + } + + expectedMsgSize := i % 4000 + payload := mr.readBytes(mr.msgBytesRemaining) + if mr.err != nil { + t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) + } + if len(payload) != expectedMsgSize { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) + } + + i++ + if i == testCount { + break + } + } +} diff --git a/query.go b/query.go index 778bc9cc..efb039d5 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "golang.org/x/net/context" "time" ) @@ -55,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription { return rows.fields } -func (rows *Rows) close() { +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. +func (rows *Rows) Close() { if rows.closed { return } @@ -67,6 +70,8 @@ func (rows *Rows) close() { rows.closed = true + rows.err = rows.conn.termContext(rows.err) + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() @@ -81,63 +86,10 @@ func (rows *Rows) close() { } } -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - case errorResponse: - err = rows.conn.rxErrorResponse(r) - if rows.err == nil { - rows.err = err - } - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } - } -} - -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.readUntilReadyForQuery() - rows.close() -} - func (rows *Rows) Err() error { return rows.err } -// abort signals that the query was not successfully sent to the server. -// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - // Fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. func (rows *Rows) Fatal(err error) { @@ -169,10 +121,6 @@ func (rows *Rows) Next() bool { } switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false case dataRow: fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { @@ -183,7 +131,9 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - case bindComplete: + rows.Close() + return false + default: err = rows.conn.processContextFreeMsg(t, r) if err != nil { @@ -441,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.lastActivityTime = time.Now() - - rows := c.getRows(sql, args) - - if err := c.lock(); err != nil { - rows.abort(err) - return rows, err - } - rows.unlockConn = true - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.Prepare("", sql) - if err != nil { - rows.abort(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) - if err != nil { - rows.abort(err) - } - return rows, rows.err + return c.QueryContext(context.Background(), sql, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -492,3 +417,51 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { rows, _ := c.Query(sql, args...) return (*Row)(rows) } + +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + + c.lastActivityTime = time.Now() + + rows = c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.Fatal(err) + return rows, err + } + rows.unlockConn = true + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.PrepareExContext(ctx, "", sql, nil) + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + + err = c.initContext(ctx) + if err != nil { + rows.Fatal(err) + return rows, err + } + + err = c.sendPreparedQuery(ps, args...) + if err != nil { + rows.Fatal(err) + err = c.termContext(err) + } + + return rows, err +} + +func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := c.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go index 15f57e49..83c2f9c1 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "fmt" + "golang.org/x/net/context" "strings" "testing" "time" @@ -1412,3 +1413,165 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 42::integer") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + if rowCount != 1 { + t.Fatalf("Expected 1 row, got %d", rowCount) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextErrorWhileReceivingRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", rows.Err()) + } + + if rowCount != 9 { + t.Fatalf("Expected 9 rows, got %d", rowCount) + } + if result != 10 { + t.Fatalf("Expected result 10, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + rows, err := conn.QueryContext(ctx, "select pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") + } + + if rows.Err() != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", rows.Err()) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result) + if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + var result []byte + err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) + if err != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/replication.go b/replication.go index 7b28d6b6..0acc9df9 100644 --- a/replication.go +++ b/replication.go @@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r * } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.reader.Peek(1) + _, err = rc.c.mr.reader.Peek(1) if err != nil { rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { @@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.abort(err) + rows.Fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.abort(err) + rows.Fatal(err) } var t byte @@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // only Oids. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(t, r); e != nil { - rows.abort(e) + rows.Fatal(e) return rows, e } } diff --git a/stdlib/sql.go b/stdlib/sql.go index 6e55996f..420b521e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -44,6 +44,7 @@ package stdlib import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return c.queryPrepared("", argsV) } +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.PrepareExContext(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPreparedContext(ctx, "", argsV) +} + func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -226,6 +242,22 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er return &Rows{rows: rows}, nil } +func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + rows, err := c.conn.QueryContext(ctx, name, args...) + if err != nil { + fmt.Println(err) + return nil, err + } + + return &Rows{rows: rows}, nil +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) @@ -318,6 +350,18 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + type Tx struct { conn *pgx.Conn } diff --git a/stress_test.go b/stress_test.go index 150d13c8..72d48a5c 100644 --- a/stress_test.go +++ b/stress_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "errors" "fmt" + "golang.org/x/net/context" "math/rand" "testing" "time" @@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + {"canceledQueryContext", canceledQueryContext}, + {"canceledExecContext", canceledExecContext}, } var timer *time.Timer @@ -63,7 +66,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- err + errChan <- fmt.Errorf("%s: %v", action.name, err) break } } @@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } + +func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + rows, err := pool.QueryContext(ctx, "select pg_sleep(2)") + if err == context.Canceled { + return nil + } else if err != nil { + return fmt.Errorf("Only allowed error is context.Canceled, got %v", err) + } + + for rows.Next() { + return errors.New("should never receive row") + } + + if rows.Err() != context.Canceled { + return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err()) + } + + return nil +} + +func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + _, err := pool.ExecContext(ctx, "select pg_sleep(2)") + if err != context.Canceled { + return fmt.Errorf("Expected context.Canceled error, got %v", err) + } + + return nil +}