diff --git a/CHANGELOG.md b/CHANGELOG.md index 2755a4ca..32acfdda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ pgconn now supports pipeline mode. `*PgConn.ReceiveResults` removed. Use pipeline mode instead. +`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error. + ## pgxpool `Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. diff --git a/pgconn/errors.go b/pgconn/errors.go index 4254535e..3c54bbec 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -19,7 +19,7 @@ func SafeToRetry(err error) bool { } // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a -// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout return errors.As(err, &timeoutErr) @@ -106,11 +106,16 @@ func (e *parseConfigError) Unwrap() error { return e.err } -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return &errTimeout{err: ctx.Err()} +func normalizeTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() { + if ctx.Err() == context.Canceled { + // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. + return context.Canceled + } else if ctx.Err() == context.DeadlineExceeded { + return &errTimeout{err: ctx.Err()} + } else { + return &errTimeout{err: err} + } } return err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 44de2897..59fa35c6 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -255,11 +255,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - err = &errTimeout{err: err} - } - return nil, &connectError{config: config, msg: "dial error", err: err} + return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } nbNetConn := nbconn.NewNetConn(netConn, false) @@ -314,7 +310,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} + return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -448,7 +444,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa if err != nil { err = &pgconnError{ msg: "receive message failed", - err: preferContextOverNetTimeoutError(ctx, err), + err: normalizeTimeoutError(ctx, err), safeToRetry: true} } return msg, err @@ -794,7 +790,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -907,7 +903,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return normalizeTimeoutError(ctx, err) } switch msg.(type) { @@ -1106,7 +1102,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1203,7 +1199,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co break } pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1238,7 +1234,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) + return CommandTag{}, normalizeTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1281,7 +1277,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = normalizeTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.asyncClose() return nil, mrr.err @@ -1497,7 +1493,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - err = preferContextOverNetTimeoutError(rr.ctx, err) + err = normalizeTimeoutError(rr.ctx, err) rr.concludeCommand(CommandTag{}, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true @@ -1814,7 +1810,7 @@ func (p *Pipeline) Flush() error { err := p.conn.frontend.Flush() if err != nil { - err = preferContextOverNetTimeoutError(p.ctx, err) + err = normalizeTimeoutError(p.ctx, err) p.conn.asyncClose() @@ -1901,7 +1897,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { msg, err := p.conn.receiveMessage() if err != nil { p.conn.asyncClose() - return nil, preferContextOverNetTimeoutError(p.ctx, err) + return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) {