diff --git a/errors.go b/errors.go index 77adfcf0..5df851d5 100644 --- a/errors.go +++ b/errors.go @@ -18,15 +18,11 @@ func SafeToRetry(err error) bool { return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - - var netErr net.Error - return errors.As(err, &netErr) && netErr.Timeout() + var timeoutErr *ErrTimeout + return errors.As(err, &timeoutErr) } // PgError represents an error reported by the PostgreSQL server. See @@ -134,6 +130,32 @@ func (e *pgconnError) Unwrap() error { return e.err } +// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type ErrTimeout struct { + err error +} + +func (e *ErrTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *ErrTimeout) SafeToRetry() bool { + var ctxErr *contextAlreadyDoneError + if errors.As(e, &ctxErr) { + return ctxErr.SafeToRetry() + } + var netErr net.Error + if errors.As(e, &netErr) { + return netErr.Temporary() + } + return false +} + +func (e *ErrTimeout) Unwrap() error { + return e.err +} + type contextAlreadyDoneError struct { err error } @@ -150,6 +172,17 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } +// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its +// deadline passed, the returned error is also wrapped by `ErrTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + ctxErr := ctx.Err() + err = &contextAlreadyDoneError{err: ctxErr} + if ctxErr != nil { + err = &ErrTimeout{err: err} + } + return err +} + type writeError struct { err error safeToRetry bool diff --git a/pgconn.go b/pgconn.go index 197aad4a..74e24257 100644 --- a/pgconn.go +++ b/pgconn.go @@ -217,6 +217,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = &ErrTimeout{err: err} + } return nil, &connectError{config: config, msg: "dial error", err: err} } @@ -389,7 +393,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if ctx != context.Background() { select { case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} + return newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -421,7 +425,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -451,7 +455,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { pgConn.bufferingReceive = false // If a timeout error happened in the background try the read again. - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { msg, err = pgConn.frontend.Receive() } } else { @@ -460,8 +465,12 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -476,8 +485,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -745,7 +758,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -918,7 +931,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -964,7 +977,7 @@ func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -1058,7 +1071,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if ctx != context.Background() { select { case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result @@ -1098,7 +1111,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1158,7 +1171,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1601,7 +1614,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: