Fix context query cancellation

Previous commits had a race condition due to not waiting for the PostgreSQL
server to close the cancel query connection. This made it possible for the
cancel request to impact a subsequent query on the same connection. This
commit sets a flag that a cancel request was made and blocks until the
PostgreSQL server closes the cancel connection.
context
Jack Christensen 2017-02-11 19:53:18 -06:00
parent deac6564ee
commit 048a75406f
3 changed files with 118 additions and 29 deletions

128
conn.go
View File

@ -93,7 +93,9 @@ type Conn struct {
status int32 // One of connStatus* constants
causeOfDeath error
readyForQuery bool // can the connection be used to send a query
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
cancelQueryInProgress int32
cancelQueryCompleted chan struct{}
// context support
ctxInProgress bool
@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.channels = make(map[string]struct{})
atomic.StoreInt32(&c.status, connStatusIdle)
c.lastActivityTime = time.Now()
c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
return c.prepareEx(name, sql, opts)
return c.PrepareExContext(context.Background(), name, sql, opts)
}
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
err = c.initContext(ctx)
if err != nil {
return nil, err
@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
}
// Deallocate released a prepared statement
func (c *Conn) Deallocate(name string) (err error) {
func (c *Conn) Deallocate(name string) error {
return c.deallocateContext(context.Background(), name)
}
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return err
}
err = c.initContext(ctx)
if err != nil {
return err
}
defer func() {
err = c.termContext(err)
}()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil
}
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
cancelFn()
return nil, err
}
cancelFn()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string {
// ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
func (c *Conn) cancelQuery() error {
network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
if err != nil {
return err
func (c *Conn) cancelQuery() {
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
panic("cancelQuery when cancelQueryInProgress")
}
defer cancelConn.Close()
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf)
return err
if err := c.conn.SetDeadline(time.Now()); err != nil {
c.Close() // Close connection if unable to set deadline
return
}
doCancel := func() error {
network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
if err != nil {
return err
}
defer cancelConn.Close()
// If server doesn't process cancellation request in bounded time then abort.
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
return err
}
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
}
return nil
}
go func() {
err := doCancel()
if err != nil {
c.Close() // Something is very wrong. Terminate the connection.
}
c.cancelQueryCompleted <- struct{}{}
}()
}
func (c *Conn) Ping() error {
@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error {
}
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return "", err
}
err = c.initContext(ctx)
if err != nil {
return "", err
@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error {
select {
case err = <-c.closedChan:
if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil {
c.Close() // Close connection if unable to disable deadline
}
if opErr == nil {
err = nil
}
@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) {
select {
case <-ctx.Done():
c.cancelQuery()
if err := c.conn.SetDeadline(time.Now()); err != nil {
c.Close() // Close connection if unable to set deadline
}
c.closedChan <- ctx.Err()
case <-c.doneChan:
}
}
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
return nil
}
select {
case <-c.cancelQueryCompleted:
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
if err := c.conn.SetDeadline(time.Time{}); err != nil {
c.Close() // Close connection if unable to disable deadline
return err
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery {
t, r, err := c.rxMsg()

View File

@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
c.lastActivityTime = time.Now()
rows = c.getRows(sql, args)

View File

@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
action := actions[rand.Intn(len(actions))]
err := action.fn(pool, n)
if err != nil {
errChan <- err
errChan <- fmt.Errorf("%s: %v", action.name, err)
break
}
}
@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
cancelFunc()
}()
rows, err := pool.QueryContext(ctx, "select pg_sleep(5)")
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
if err == context.Canceled {
return nil
} else if err != nil {
return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err)
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
}
for rows.Next() {
return errors.New("canceledQueryContext: should never receive row")
return errors.New("should never receive row")
}
if rows.Err() != context.Canceled {
return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err())
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
}
return nil
@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
cancelFunc()
}()
_, err := pool.ExecContext(ctx, "select pg_sleep(5)")
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
if err != context.Canceled {
return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err)
return fmt.Errorf("Expected context.Canceled error, got %v", err)
}
return nil