mirror of https://github.com/jackc/pgx.git
Context errors returned instead of net.Error
The net.Error caused by using SetDeadline to implement context cancellation shouldn't leak. fixes #80query-exec-mode
parent
13d454882b
commit
6996e8d6c5
10
errors.go
10
errors.go
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -105,6 +106,15 @@ func (e *parseConfigError) Unwrap() error {
|
|||
return e.err
|
||||
}
|
||||
|
||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||
// true. Otherwise returns err.
|
||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||
return &errTimeout{err: ctx.Err()}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
|
|
28
pgconn.go
28
pgconn.go
|
@ -271,7 +271,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
if err, ok := err.(*PgError); ok {
|
||||
return nil, err
|
||||
}
|
||||
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
|
||||
return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)}
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -434,7 +434,10 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
|
|||
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
|
||||
err = &pgconnError{
|
||||
msg: "receive message failed",
|
||||
err: preferContextOverNetTimeoutError(ctx, err),
|
||||
safeToRetry: true}
|
||||
}
|
||||
return msg, err
|
||||
}
|
||||
|
@ -469,8 +472,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
|
|||
isNetErr := errors.As(err, &netErr)
|
||||
if !(isNetErr && netErr.Timeout()) {
|
||||
pgConn.asyncClose()
|
||||
} else if isNetErr && netErr.Timeout() {
|
||||
err = &errTimeout{err: err}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
|
@ -489,8 +490,6 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
|||
isNetErr := errors.As(err, &netErr)
|
||||
if !(isNetErr && netErr.Timeout()) {
|
||||
pgConn.asyncClose()
|
||||
} else if isNetErr && netErr.Timeout() {
|
||||
err = &errTimeout{err: err}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
|
@ -785,7 +784,7 @@ readloop:
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -888,7 +887,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return newContextAlreadyDoneError(ctx)
|
||||
default:
|
||||
}
|
||||
|
||||
|
@ -899,7 +898,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||
for {
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
|
@ -1136,7 +1135,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -1196,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -1255,7 +1254,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -1287,7 +1286,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -1329,7 +1328,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||
|
||||
if err != nil {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.err = err
|
||||
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
|
||||
mrr.closed = true
|
||||
mrr.pgConn.asyncClose()
|
||||
return nil, mrr.err
|
||||
|
@ -1536,6 +1535,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
err = preferContextOverNetTimeoutError(rr.ctx, err)
|
||||
rr.concludeCommand(nil, err)
|
||||
rr.pgConn.contextWatcher.Unwatch()
|
||||
rr.closed = true
|
||||
|
|
|
@ -585,6 +585,7 @@ func TestConnExecContextCanceled(t *testing.T) {
|
|||
}
|
||||
err = multiResult.Close()
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
select {
|
||||
case <-pgConn.CleanupDone():
|
||||
|
@ -729,6 +730,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
|||
commandTag, err := result.Close()
|
||||
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
select {
|
||||
|
@ -1289,7 +1291,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
err = pgConn.WaitForNotification(ctx)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -1308,6 +1310,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|||
err = pgConn.WaitForNotification(ctx)
|
||||
cancel()
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue