From 6996e8d6c546d45bab6f1e8b24c010f40f095e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 09:09:22 -0500 Subject: [PATCH] Context errors returned instead of net.Error The net.Error caused by using SetDeadline to implement context cancellation shouldn't leak. fixes #80 --- errors.go | 10 ++++++++++ pgconn.go | 28 ++++++++++++++-------------- pgconn_test.go | 5 ++++- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/errors.go b/errors.go index 64401d65..a32b29c9 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "regexp" "strings" @@ -105,6 +106,15 @@ func (e *parseConfigError) Unwrap() error { return e.err } +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return &errTimeout{err: ctx.Err()} + } + return err +} + type pgconnError struct { msg string err error diff --git a/pgconn.go b/pgconn.go index a17a108d..43b13e43 100644 --- a/pgconn.go +++ b/pgconn.go @@ -271,7 +271,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: err} + return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -434,7 +434,10 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} + err = &pgconnError{ + msg: "receive message failed", + err: preferContextOverNetTimeoutError(ctx, err), + safeToRetry: true} } return msg, err } @@ -469,8 +472,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -489,8 +490,6 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -785,7 +784,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -888,7 +887,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if ctx != context.Background() { select { case <-ctx.Done(): - return ctx.Err() + return newContextAlreadyDoneError(ctx) default: } @@ -899,7 +898,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return err + return preferContextOverNetTimeoutError(ctx, err) } switch msg.(type) { @@ -1136,7 +1135,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1196,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1255,7 +1254,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1287,7 +1286,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1329,7 +1328,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = err + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.asyncClose() return nil, mrr.err @@ -1536,6 +1535,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { + err = preferContextOverNetTimeoutError(rr.ctx, err) rr.concludeCommand(nil, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 7ceda791..c20b7425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -585,6 +585,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): @@ -729,6 +730,7 @@ func TestConnExecParamsCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { @@ -1289,7 +1291,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() err = pgConn.WaitForNotification(ctx) - require.Equal(t, context.Canceled, err) + require.ErrorIs(t, err, context.Canceled) ensureConnValid(t, pgConn) } @@ -1308,6 +1310,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { err = pgConn.WaitForNotification(ctx) cancel() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) ensureConnValid(t, pgConn) }