mirror of https://github.com/jackc/pgx.git
Fix context query cancellation
Previous commits had a race condition due to not waiting for the PostgreSQL server to close the cancel query connection. This made it possible for the cancel request to impact a subsequent query on the same connection. This commit sets a flag that a cancel request was made and blocks until the PostgreSQL server closes the cancel connection.context
parent
deac6564ee
commit
048a75406f
128
conn.go
128
conn.go
|
@ -93,7 +93,9 @@ type Conn struct {
|
|||
status int32 // One of connStatus* constants
|
||||
causeOfDeath error
|
||||
|
||||
readyForQuery bool // can the connection be used to send a query
|
||||
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
||||
cancelQueryInProgress int32
|
||||
cancelQueryCompleted chan struct{}
|
||||
|
||||
// context support
|
||||
ctxInProgress bool
|
||||
|
@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.channels = make(map[string]struct{})
|
||||
atomic.StoreInt32(&c.status, connStatusIdle)
|
||||
c.lastActivityTime = time.Now()
|
||||
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||
c.doneChan = make(chan struct{})
|
||||
c.closedChan = make(chan error)
|
||||
|
||||
|
@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
return c.prepareEx(name, sql, opts)
|
||||
return c.PrepareExContext(context.Background(), name, sql, opts)
|
||||
}
|
||||
|
||||
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
}
|
||||
|
||||
// Deallocate released a prepared statement
|
||||
func (c *Conn) Deallocate(name string) (err error) {
|
||||
func (c *Conn) Deallocate(name string) error {
|
||||
return c.deallocateContext(context.Background(), name)
|
||||
}
|
||||
|
||||
// TODO - consider making this public
|
||||
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||
return notification, nil
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
|
||||
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
|
||||
cancelFn()
|
||||
return nil, err
|
||||
}
|
||||
cancelFn()
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string {
|
|||
// ensure that the query was canceled. As specified in the documentation, there
|
||||
// is no way to be sure a query was canceled. See
|
||||
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
||||
func (c *Conn) cancelQuery() error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
func (c *Conn) cancelQuery() {
|
||||
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
|
||||
panic("cancelQuery when cancelQueryInProgress")
|
||||
}
|
||||
defer cancelConn.Close()
|
||||
|
||||
buf := make([]byte, 16)
|
||||
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
return err
|
||||
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
||||
c.Close() // Close connection if unable to set deadline
|
||||
return
|
||||
}
|
||||
|
||||
doCancel := func() error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancelConn.Close()
|
||||
|
||||
// If server doesn't process cancellation request in bounded time then abort.
|
||||
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 16)
|
||||
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cancelConn.Read(buf)
|
||||
if err != io.EOF {
|
||||
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := doCancel()
|
||||
if err != nil {
|
||||
c.Close() // Something is very wrong. Terminate the connection.
|
||||
}
|
||||
c.cancelQueryCompleted <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Conn) Ping() error {
|
||||
|
@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error {
|
|||
|
||||
select {
|
||||
case err = <-c.closedChan:
|
||||
if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil {
|
||||
c.Close() // Close connection if unable to disable deadline
|
||||
}
|
||||
if opErr == nil {
|
||||
err = nil
|
||||
}
|
||||
|
@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancelQuery()
|
||||
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
||||
c.Close() // Close connection if unable to set deadline
|
||||
}
|
||||
c.closedChan <- ctx.Err()
|
||||
case <-c.doneChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
||||
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.cancelQueryCompleted:
|
||||
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
|
||||
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
||||
c.Close() // Close connection if unable to disable deadline
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||
for !c.readyForQuery {
|
||||
t, r, err := c.rxMsg()
|
||||
|
|
5
query.go
5
query.go
|
@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
|||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.lastActivityTime = time.Now()
|
||||
|
||||
rows = c.getRows(sql, args)
|
||||
|
|
|
@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
|
|||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pool, n)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
errChan <- fmt.Errorf("%s: %v", action.name, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
rows, err := pool.QueryContext(ctx, "select pg_sleep(5)")
|
||||
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err)
|
||||
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
return errors.New("canceledQueryContext: should never receive row")
|
||||
return errors.New("should never receive row")
|
||||
}
|
||||
|
||||
if rows.Err() != context.Canceled {
|
||||
return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err())
|
||||
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := pool.ExecContext(ctx, "select pg_sleep(5)")
|
||||
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
|
||||
if err != context.Canceled {
|
||||
return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err)
|
||||
return fmt.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue