diff --git a/conn.go b/conn.go index a8b61547..71dc286a 100644 --- a/conn.go +++ b/conn.go @@ -20,7 +20,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/pkg/errors" @@ -78,6 +77,7 @@ type ConnConfig struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) OnNotice NoticeHandler // Callback function called when a notice response is received. CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. + CustomCancel func(*Conn) error // Callback function used to override cancellation behavior // PreferSimpleProtocol disables implicit prepared statement usage. By default // pgx automatically uses the unnamed prepared statement for Query and @@ -133,8 +133,7 @@ type Conn struct { status byte // One of connStatus* constants causeOfDeath error - pendingReadyForQueryCount int // numer of ReadyForQuery messages expected - cancelQueryInProgress int32 + pendingReadyForQueryCount int // number of ReadyForQuery messages expected cancelQueryCompleted chan struct{} // context support @@ -309,7 +308,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) c.lastActivityTime = time.Now() - c.cancelQueryCompleted = make(chan struct{}, 1) + c.cancelQueryCompleted = make(chan struct{}) + close(c.cancelQueryCompleted) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) @@ -620,6 +620,14 @@ func (c *Conn) PID() uint32 { return c.pid } +// LocalAddr returns the underlying connection's local address +func (c *Conn) LocalAddr() (net.Addr, error) { + if !c.IsAlive() { + return nil, errors.New("connection not ready") + } + return c.conn.LocalAddr(), nil +} + // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { @@ -1656,59 +1664,67 @@ func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } +func doCancel(c *Conn) error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. + now := time.Now() + err = cancelConn.SetDeadline(now.Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil +} + // cancelQuery sends a cancel request to the PostgreSQL server. It returns an // error if unable to deliver the cancel request, but lack of an error does not // ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 func (c *Conn) cancelQuery() { - if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { - panic("cancelQuery when cancelQueryInProgress") - } - if err := c.conn.SetDeadline(time.Now()); err != nil { c.Close() // Close connection if unable to set deadline return } - doCancel := func() error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) - if err != nil { - return err - } - defer cancelConn.Close() - - // If server doesn't process cancellation request in bounded time then abort. - err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - return err - } - - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) - _, err = cancelConn.Write(buf) - if err != nil { - return err - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) - } - - return nil + var cancelFn func(*Conn) error + completeCh := make(chan struct{}) + c.mux.Lock() + c.cancelQueryCompleted = completeCh + c.mux.Unlock() + if c.config.CustomCancel != nil { + cancelFn = c.config.CustomCancel + } else { + cancelFn = doCancel } go func() { - err := doCancel() + defer close(completeCh) + err := cancelFn(c) if err != nil { c.Close() // Something is very wrong. Terminate the connection. } - c.cancelQueryCompleted <- struct{}{} }() } @@ -1893,14 +1909,21 @@ func (c *Conn) contextHandler(ctx context.Context) { } } -func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { - if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { - return nil +// WaitUntilReady will return when the connection is ready for another query +func (c *Conn) WaitUntilReady(ctx context.Context) error { + err := c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err } + return c.ensureConnectionReadyForQuery() +} +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + c.mux.Lock() + completeCh := c.cancelQueryCompleted + c.mux.Unlock() select { - case <-c.cancelQueryCompleted: - atomic.StoreInt32(&c.cancelQueryInProgress, 0) + case <-completeCh: if err := c.conn.SetDeadline(time.Time{}); err != nil { c.Close() // Close connection if unable to disable deadline return err