diff --git a/conn.go b/conn.go index 5ede5944..f91929c5 100644 --- a/conn.go +++ b/conn.go @@ -93,7 +93,9 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error - readyForQuery bool // can the connection be used to send a query + readyForQuery bool // connection has received ReadyForQuery message since last query was sent + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} // context support ctxInProgress bool @@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() + c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.prepareEx(name, sql, opts) + return c.PrepareExContext(context.Background(), name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + err = c.initContext(ctx) if err != nil { return nil, err @@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) (err error) { +func (c *Conn) Deallocate(name string) error { + return c.deallocateContext(context.Background(), name) +} + +// TODO - consider making this public +func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + err = c.initContext(ctx) + if err != nil { + return err + } + defer func() { + err = c.termContext(err) + }() + if err := c.ensureConnectionReadyForQuery(); err != nil { return err } @@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + ctx, cancelFn := context.WithTimeout(context.Background(), timeout) + if err := c.waitForPreviousCancelQuery(ctx); err != nil { + cancelFn() + return nil, err + } + cancelFn() + if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } @@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string { // ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 -func (c *Conn) cancelQuery() error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) - if err != nil { - return err +func (c *Conn) cancelQuery() { + if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { + panic("cancelQuery when cancelQueryInProgress") } - defer cancelConn.Close() - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) - _, err = cancelConn.Write(buf) - return err + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + return + } + + doCancel := func() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. + err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil + } + + go func() { + err := doCancel() + if err != nil { + c.Close() // Something is very wrong. Terminate the connection. + } + c.cancelQueryCompleted <- struct{}{} + }() } func (c *Conn) Ping() error { @@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return "", err + } + err = c.initContext(ctx) if err != nil { return "", err @@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error { select { case err = <-c.closedChan: - if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil { - c.Close() // Close connection if unable to disable deadline - } if opErr == nil { err = nil } @@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) { select { case <-ctx.Done(): c.cancelQuery() - if err := c.conn.SetDeadline(time.Now()); err != nil { - c.Close() // Close connection if unable to set deadline - } c.closedChan <- ctx.Err() case <-c.doneChan: } } +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + return nil + } + + select { + case <-c.cancelQueryCompleted: + atomic.StoreInt32(&c.cancelQueryInProgress, 0) + if err := c.conn.SetDeadline(time.Time{}); err != nil { + c.Close() // Close connection if unable to disable deadline + return err + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + func (c *Conn) ensureConnectionReadyForQuery() error { for !c.readyForQuery { t, r, err := c.rxMsg() diff --git a/query.go b/query.go index aa664649..dd7aafb0 100644 --- a/query.go +++ b/query.go @@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + c.lastActivityTime = time.Now() rows = c.getRows(sql, args) diff --git a/stress_test.go b/stress_test.go index d22d9d6b..72d48a5c 100644 --- a/stress_test.go +++ b/stress_test.go @@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- err + errChan <- fmt.Errorf("%s: %v", action.name, err) break } } @@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + rows, err := pool.QueryContext(ctx, "select pg_sleep(2)") if err == context.Canceled { return nil } else if err != nil { - return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + return fmt.Errorf("Only allowed error is context.Canceled, got %v", err) } for rows.Next() { - return errors.New("canceledQueryContext: should never receive row") + return errors.New("should never receive row") } if rows.Err() != context.Canceled { - return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err()) } return nil @@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + _, err := pool.ExecContext(ctx, "select pg_sleep(2)") if err != context.Canceled { - return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + return fmt.Errorf("Expected context.Canceled error, got %v", err) } return nil