Context errors returned instead of net.Error

The net.Error caused by using SetDeadline to implement context
cancellation shouldn't leak.

fixes #80
query-exec-mode
Jack Christensen 2021-07-24 09:09:22 -05:00
parent 13d454882b
commit 6996e8d6c5
3 changed files with 28 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
@ -105,6 +106,15 @@ func (e *parseConfigError) Unwrap() error {
return e.err
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return &errTimeout{err: ctx.Err()}
}
return err
}
type pgconnError struct {
msg string
err error

View File

@ -271,7 +271,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)}
}
switch msg := msg.(type) {
@ -434,7 +434,10 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
msg, err := pgConn.receiveMessage()
if err != nil {
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
err = &pgconnError{
msg: "receive message failed",
err: preferContextOverNetTimeoutError(ctx, err),
safeToRetry: true}
}
return msg, err
}
@ -469,8 +472,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &errTimeout{err: err}
}
return nil, err
@ -489,8 +490,6 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
} else if isNetErr && netErr.Timeout() {
err = &errTimeout{err: err}
}
return nil, err
@ -785,7 +784,7 @@ readloop:
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, err
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -888,7 +887,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
if ctx != context.Background() {
select {
case <-ctx.Done():
return ctx.Err()
return newContextAlreadyDoneError(ctx)
default:
}
@ -899,7 +898,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for {
msg, err := pgConn.receiveMessage()
if err != nil {
return err
return preferContextOverNetTimeoutError(ctx, err)
}
switch msg.(type) {
@ -1136,7 +1135,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, err
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1196,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, err
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1255,7 +1254,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, err
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1287,7 +1286,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
return nil, err
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
@ -1329,7 +1328,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
if err != nil {
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = err
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.asyncClose()
return nil, mrr.err
@ -1536,6 +1535,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
}
if err != nil {
err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.concludeCommand(nil, err)
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true

View File

@ -585,6 +585,7 @@ func TestConnExecContextCanceled(t *testing.T) {
}
err = multiResult.Close()
assert.True(t, pgconn.Timeout(err))
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.True(t, pgConn.IsClosed())
select {
case <-pgConn.CleanupDone():
@ -729,6 +730,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.True(t, pgconn.Timeout(err))
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.True(t, pgConn.IsClosed())
select {
@ -1289,7 +1291,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
err = pgConn.WaitForNotification(ctx)
require.Equal(t, context.Canceled, err)
require.ErrorIs(t, err, context.Canceled)
ensureConnValid(t, pgConn)
}
@ -1308,6 +1310,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
err = pgConn.WaitForNotification(ctx)
cancel()
assert.True(t, pgconn.Timeout(err))
assert.ErrorIs(t, err, context.DeadlineExceeded)
ensureConnValid(t, pgConn)
}