From 1a314bda3b604205092e4f48df90bbcbfaaef960 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 17 Sep 2022 10:18:06 -0500
Subject: [PATCH] pgconn.Timeout() no longer considers `context.Canceled` as a
 timeout error.

https://github.com/jackc/pgconn/issues/81
---
 CHANGELOG.md     |  2 ++
 pgconn/errors.go | 17 +++++++++++------
 pgconn/pgconn.go | 28 ++++++++++++----------------
 3 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2755a4ca..32acfdda 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,8 @@ pgconn now supports pipeline mode.
 
 `*PgConn.ReceiveResults` removed. Use pipeline mode instead.
 
+`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
+
 ## pgxpool
 
 `Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
diff --git a/pgconn/errors.go b/pgconn/errors.go
index 4254535e..3c54bbec 100644
--- a/pgconn/errors.go
+++ b/pgconn/errors.go
@@ -19,7 +19,7 @@ func SafeToRetry(err error) bool {
 }
 
 // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
-// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
+// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
 func Timeout(err error) bool {
 	var timeoutErr *errTimeout
 	return errors.As(err, &timeoutErr)
@@ -106,11 +106,16 @@ func (e *parseConfigError) Unwrap() error {
 	return e.err
 }
 
-// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
-// true. Otherwise returns err.
-func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
-	if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
-		return &errTimeout{err: ctx.Err()}
+func normalizeTimeoutError(ctx context.Context, err error) error {
+	if err, ok := err.(net.Error); ok && err.Timeout() {
+		if ctx.Err() == context.Canceled {
+			// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
+			return context.Canceled
+		} else if ctx.Err() == context.DeadlineExceeded {
+			return &errTimeout{err: ctx.Err()}
+		} else {
+			return &errTimeout{err: err}
+		}
 	}
 	return err
 }
diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go
index 44de2897..59fa35c6 100644
--- a/pgconn/pgconn.go
+++ b/pgconn/pgconn.go
@@ -255,11 +255,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 	network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
 	netConn, err := config.DialFunc(ctx, network, address)
 	if err != nil {
-		var netErr net.Error
-		if errors.As(err, &netErr) && netErr.Timeout() {
-			err = &errTimeout{err: err}
-		}
-		return nil, &connectError{config: config, msg: "dial error", err: err}
+		return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
 	}
 	nbNetConn := nbconn.NewNetConn(netConn, false)
 
@@ -314,7 +310,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 			if err, ok := err.(*PgError); ok {
 				return nil, err
 			}
-			return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)}
+			return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
 		}
 
 		switch msg := msg.(type) {
@@ -448,7 +444,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
 	if err != nil {
 		err = &pgconnError{
 			msg:         "receive message failed",
-			err:         preferContextOverNetTimeoutError(ctx, err),
+			err:         normalizeTimeoutError(ctx, err),
 			safeToRetry: true}
 	}
 	return msg, err
@@ -794,7 +790,7 @@ readloop:
 		msg, err := pgConn.receiveMessage()
 		if err != nil {
 			pgConn.asyncClose()
-			return nil, preferContextOverNetTimeoutError(ctx, err)
+			return nil, normalizeTimeoutError(ctx, err)
 		}
 
 		switch msg := msg.(type) {
@@ -907,7 +903,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
 	for {
 		msg, err := pgConn.receiveMessage()
 		if err != nil {
-			return preferContextOverNetTimeoutError(ctx, err)
+			return normalizeTimeoutError(ctx, err)
 		}
 
 		switch msg.(type) {
@@ -1106,7 +1102,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
 		msg, err := pgConn.receiveMessage()
 		if err != nil {
 			pgConn.asyncClose()
-			return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
+			return CommandTag{}, normalizeTimeoutError(ctx, err)
 		}
 
 		switch msg := msg.(type) {
@@ -1203,7 +1199,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 					break
 				}
 				pgConn.asyncClose()
-				return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
+				return CommandTag{}, normalizeTimeoutError(ctx, err)
 			}
 
 			switch msg := msg.(type) {
@@ -1238,7 +1234,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 		msg, err := pgConn.receiveMessage()
 		if err != nil {
 			pgConn.asyncClose()
-			return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
+			return CommandTag{}, normalizeTimeoutError(ctx, err)
 		}
 
 		switch msg := msg.(type) {
@@ -1281,7 +1277,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
 
 	if err != nil {
 		mrr.pgConn.contextWatcher.Unwatch()
-		mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
+		mrr.err = normalizeTimeoutError(mrr.ctx, err)
 		mrr.closed = true
 		mrr.pgConn.asyncClose()
 		return nil, mrr.err
@@ -1497,7 +1493,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
 	}
 
 	if err != nil {
-		err = preferContextOverNetTimeoutError(rr.ctx, err)
+		err = normalizeTimeoutError(rr.ctx, err)
 		rr.concludeCommand(CommandTag{}, err)
 		rr.pgConn.contextWatcher.Unwatch()
 		rr.closed = true
@@ -1814,7 +1810,7 @@ func (p *Pipeline) Flush() error {
 
 	err := p.conn.frontend.Flush()
 	if err != nil {
-		err = preferContextOverNetTimeoutError(p.ctx, err)
+		err = normalizeTimeoutError(p.ctx, err)
 
 		p.conn.asyncClose()
 
@@ -1901,7 +1897,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
 		msg, err := p.conn.receiveMessage()
 		if err != nil {
 			p.conn.asyncClose()
-			return nil, preferContextOverNetTimeoutError(p.ctx, err)
+			return nil, normalizeTimeoutError(p.ctx, err)
 		}
 
 		switch msg := msg.(type) {