Implement timeout error

Signed-off-by: Michael Darr <michael.e.darr@gmail.com>
query-exec-mode
Michael Darr 2021-06-29 14:24:09 -04:00 committed by Jack Christensen
parent a123e5b4e5
commit c0b4d3bc05
2 changed files with 65 additions and 19 deletions

View File

@ -18,15 +18,11 @@ func SafeToRetry(err error) bool {
return false
}
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
var timeoutErr *ErrTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
@ -134,6 +130,32 @@ func (e *pgconnError) Unwrap() error {
return e.err
}
// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type ErrTimeout struct {
err error
}
func (e *ErrTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *ErrTimeout) SafeToRetry() bool {
var ctxErr *contextAlreadyDoneError
if errors.As(e, &ctxErr) {
return ctxErr.SafeToRetry()
}
var netErr net.Error
if errors.As(e, &netErr) {
return netErr.Temporary()
}
return false
}
func (e *ErrTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
@ -150,6 +172,17 @@ func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its
// deadline passed, the returned error is also wrapped by `ErrTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
ctxErr := ctx.Err()
err = &contextAlreadyDoneError{err: ctxErr}
if ctxErr != nil {
err = &ErrTimeout{err: err}
}
return err
}
type writeError struct {
err error
safeToRetry bool

View File

@ -217,6 +217,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
err = &ErrTimeout{err: err}
}
return nil, &connectError{config: config, msg: "dial error", err: err}
}
@ -389,7 +393,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if ctx != context.Background() {
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
return newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -421,7 +425,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
return nil, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -451,7 +455,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
@ -460,8 +465,12 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
if err != nil {
// Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &ErrTimeout{err: err}
}
return nil, err
@ -476,8 +485,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.peekMessage()
if err != nil {
// Close on anything other than timeout error - everything else is fatal
if err, ok := err.(net.Error); !(ok && err.Timeout()) {
var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &ErrTimeout{err: err}
}
return nil, err
@ -745,7 +758,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
return nil, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -918,7 +931,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock()
return multiResult
default:
@ -964,7 +977,7 @@ func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock()
return multiResult
default:
@ -1058,7 +1071,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
if ctx != context.Background() {
select {
case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.concludeCommand(nil, newContextAlreadyDoneError(ctx))
result.closed = true
pgConn.unlock()
return result
@ -1098,7 +1111,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
select {
case <-ctx.Done():
pgConn.unlock()
return nil, &contextAlreadyDoneError{err: ctx.Err()}
return nil, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -1158,7 +1171,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
return nil, newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -1601,7 +1614,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
multiResult.err = newContextAlreadyDoneError(ctx)
pgConn.unlock()
return multiResult
default: