Implement timeout error

Signed-off-by: Michael Darr <michael.e.darr@gmail.com>
query-exec-mode
Michael Darr 2021-06-29 14:24:09 -04:00 committed by Jack Christensen
parent a123e5b4e5
commit c0b4d3bc05
2 changed files with 65 additions and 19 deletions

View File

@ -18,15 +18,11 @@ func SafeToRetry(err error) bool {
return false return false
} }
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool { func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { var timeoutErr *ErrTimeout
return true return errors.As(err, &timeoutErr)
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
} }
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
@ -134,6 +130,32 @@ func (e *pgconnError) Unwrap() error {
return e.err return e.err
} }
// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type ErrTimeout struct {
err error
}
func (e *ErrTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *ErrTimeout) SafeToRetry() bool {
var ctxErr *contextAlreadyDoneError
if errors.As(e, &ctxErr) {
return ctxErr.SafeToRetry()
}
var netErr net.Error
if errors.As(e, &netErr) {
return netErr.Temporary()
}
return false
}
func (e *ErrTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct { type contextAlreadyDoneError struct {
err error err error
} }
@ -150,6 +172,17 @@ func (e *contextAlreadyDoneError) Unwrap() error {
return e.err return e.err
} }
// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its
// deadline passed, the returned error is also wrapped by `ErrTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
ctxErr := ctx.Err()
err = &contextAlreadyDoneError{err: ctxErr}
if ctxErr != nil {
err = &ErrTimeout{err: err}
}
return err
}
type writeError struct { type writeError struct {
err error err error
safeToRetry bool safeToRetry bool

View File

@ -217,6 +217,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address) pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil { if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
err = &ErrTimeout{err: err}
}
return nil, &connectError{config: config, msg: "dial error", err: err} return nil, &connectError{config: config, msg: "dial error", err: err}
} }
@ -389,7 +393,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()} return newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -421,7 +425,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -451,7 +455,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
pgConn.bufferingReceive = false pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again. // If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive() msg, err = pgConn.frontend.Receive()
} }
} else { } else {
@ -460,8 +465,12 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
if err != nil { if err != nil {
// Close on anything other than timeout error - everything else is fatal // Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) { var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose() pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &ErrTimeout{err: err}
} }
return nil, err return nil, err
@ -476,8 +485,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.peekMessage() msg, err := pgConn.peekMessage()
if err != nil { if err != nil {
// Close on anything other than timeout error - everything else is fatal // Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) { var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose() pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &ErrTimeout{err: err}
} }
return nil, err return nil, err
@ -745,7 +758,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -918,7 +931,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:
@ -964,7 +977,7 @@ func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:
@ -1058,7 +1071,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.concludeCommand(nil, newContextAlreadyDoneError(ctx))
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@ -1098,7 +1111,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -1158,7 +1171,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
if ctx != context.Background() { if ctx != context.Background() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, newContextAlreadyDoneError(ctx)
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -1601,7 +1614,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default: