Refactor errors

- Use strongly typed errors internally
- SafeToRetry(error) streamlines retry logic over ErrNoBytesSent
- Timeout(error) removes the need to choose between returning a context
  and an i/o error
query-exec-mode
Jack Christensen 2019-08-27 18:01:59 -05:00
parent e6cf51b304
commit 138254da5b
4 changed files with 197 additions and 143 deletions

View File

@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) {
if strings.HasPrefix(connString, "postgres://") { if strings.HasPrefix(connString, "postgres://") {
err := addURLSettings(settings, connString) err := addURLSettings(settings, connString)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
} }
} else { } else {
err := addDSNSettings(settings, connString) err := addDSNSettings(settings, connString)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
} }
} }
} }
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil { if err != nil {
return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
} }
config := &Config{ config := &Config{
@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) {
if connectTimeout, present := settings["connect_timeout"]; present { if connectTimeout, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
} }
config.DialFunc = dialFunc config.DialFunc = dialFunc
} else { } else {
@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
port, err := parsePort(portStr) port, err := parsePort(portStr)
if err != nil { if err != nil {
return nil, errors.Errorf("invalid port: %w", err) return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
} }
var tlsConfigs []*tls.Config var tlsConfigs []*tls.Config
@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
var err error var err error
tlsConfigs, err = configTLS(settings) tlsConfigs, err = configTLS(settings)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
} }
} }
@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
if settings["target_session_attrs"] == "read-write" { if settings["target_session_attrs"] == "read-write" {
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" { } else if settings["target_session_attrs"] != "any" {
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
} }
return config, nil return config, nil

160
errors.go
View File

@ -2,22 +2,31 @@ package pgconn
import ( import (
"context" "context"
"fmt"
"net" "net"
"strings"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
// ErrTLSRefused occurs when the connection attempt requires TLS and the // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
// PostgreSQL server refuses to use TLS func SafeToRetry(err error) bool {
var ErrTLSRefused = errors.New("server refused TLS connection") if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another // Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
// action is attempted. // context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
var ErrConnBusy = errors.New("conn is busy") func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used var netErr net.Error
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. return errors.As(err, &netErr) && netErr.Timeout()
var ErrNoBytesSent = errors.New("no bytes sent to server") }
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
} }
// linkedError connects two errors as if err wrapped next. type connectError struct {
type linkedError struct { config *Config
err error msg string
next error err error
} }
func (le *linkedError) Error() string { func (e *connectError) Error() string {
return le.err.Error() sb := &strings.Builder{}
} fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
func (le *linkedError) Is(target error) bool { fmt.Fprintf(sb, " (%s)", e.err.Error())
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
} }
return err return sb.String()
} }
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. func (e *connectError) Unwrap() error {
func linkErrors(outer, inner error) error { return e.err
if outer == nil { }
return inner
} type connLockError struct {
if inner == nil { status string
return outer }
}
return &linkedError{err: outer, next: inner} func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
} }

125
pgconn.go
View File

@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
if err == nil { if err == nil {
break break
} else if err, ok := err.(*PgError); ok { } else if err, ok := err.(*PgError); ok {
return nil, err return nil, &connectError{config: config, msg: "server error", err: err}
} }
} }
if err != nil { if err != nil {
return nil, err return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
} }
if config.AfterConnect != nil { if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn) err := config.AfterConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err) return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
} }
} }
@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address) pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, &connectError{config: config, msg: "dial error", err: err}
} }
pgConn.parameterStatuses = make(map[string]string) pgConn.parameterStatuses = make(map[string]string)
@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if fallbackConfig.TLSConfig != nil { if fallbackConfig.TLSConfig != nil {
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "tls error", err: err}
} }
} }
@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
} }
for { for {
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
case *pgproto3.Authentication: case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil { if err = pgConn.rxAuthenticationX(msg); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "failed handle authentication message", err: err}
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle pgConn.status = connStatusIdle
@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err := config.ValidateConnect(ctx, pgConn) err := config.ValidateConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.Errorf("ValidateConnect: %v", err) return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
} }
} }
return pgConn, nil return pgConn, nil
@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
default: default:
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.New("unexpected message") return nil, &connectError{config: config, msg: "received unexpected message", err: err}
} }
} }
} }
@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
} }
if response[0] != 'S' { if response[0] != 'S' {
return ErrTLSRefused return errors.New("server refused TLS connection")
} }
pgConn.conn = tls.Client(pgConn.conn, tlsConfig) pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
// See https://www.postgresql.org/docs/current/protocol.html. // See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return linkErrors(err, ErrNoBytesSent) return err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return linkErrors(ctx.Err(), ErrNoBytesSent) return &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { return &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return linkErrors(ctx.Err(), err)
} }
return nil return nil
@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
// See https://www.postgresql.org/docs/current/protocol.html. // See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
err = linkErrors(ctx.Err(), err) err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
} }
return msg, err return msg, err
} }
@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
_, err = pgConn.conn.Read(make([]byte, 1)) _, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF { if err != io.EOF {
return linkErrors(ctx.Err(), err) return err
} }
return pgConn.conn.Close() return pgConn.conn.Close()
@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool {
return pgConn.status < connStatusIdle return pgConn.status < connStatusIdle
} }
// lock locks the connection. It panics if the connection is already locked or is closed. // lock locks the connection.
func (pgConn *PgConn) lock() error { func (pgConn *PgConn) lock() error {
switch pgConn.status { switch pgConn.status {
case connStatusBusy: case connStatusBusy:
return ErrConnBusy // This only should be possible in case of an application bug. return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
case connStatusClosed: case connStatusClosed:
return errors.New("conn closed") return &connLockError{status: "conn closed"}
case connStatusUninitialized: case connStatusUninitialized:
return errors.New("conn uninitialized") return &connLockError{status: "conn uninitialized"}
} }
pgConn.status = connStatusBusy pgConn.status = connStatusBusy
return nil return nil
@ -527,13 +527,13 @@ type StatementDescription struct {
// allows Prepare to also to describe statements without creating a server-side prepared statement. // allows Prepare to also to describe statements without creating a server-side prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
psd := &StatementDescription{Name: name, SQL: sql} psd := &StatementDescription{Name: name, SQL: sql}
@ -562,7 +559,7 @@ readloop:
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf) _, err = cancelConn.Write(buf)
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
_, err = cancelConn.Read(buf) _, err = cancelConn.Read(buf)
if err != io.EOF { if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) return err
} }
return nil return nil
@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for { for {
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
switch msg.(type) { switch msg.(type) {
@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return &MultiResultReader{ return &MultiResultReader{
closed: true, closed: true,
err: linkErrors(err, ErrNoBytesSent), err: err,
} }
} }
@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:
@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose() pgConn.hardClose()
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
multiResult.closed = true multiResult.closed = true
if n == 0 { multiResult.err = &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
multiResult.err = linkErrors(ctx.Err(), err)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
} }
@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
result := &pgConn.resultReader result := &pgConn.resultReader
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) result.concludeCommand(nil, err)
result.closed = true result.closed = true
return result return result
} }
@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
// CopyTo executes the copy command sql and copies the results to w. // CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
pgConn.unlock() pgConn.unlock()
if n == 0 { return nil, &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
// Read results // Read results
@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block. // could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { return nil, &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
// Read until copy in response or error. // Read until copy in response or error.
@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
} }
@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
// Read results // Read results
@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
if err != nil { if err != nil {
mrr.pgConn.contextWatcher.Unwatch() mrr.pgConn.contextWatcher.Unwatch()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.err = err
mrr.closed = true mrr.closed = true
mrr.pgConn.hardClose() mrr.pgConn.hardClose()
return nil, mrr.err return nil, mrr.err
@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
} }
rr.commandTag = commandTag rr.commandTag = commandTag
rr.err = preferContextOverNetTimeoutError(rr.ctx, err) rr.err = err
rr.fieldDescriptions = nil rr.fieldDescriptions = nil
rr.rowValues = nil rr.rowValues = nil
rr.commandConcluded = true rr.commandConcluded = true
@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return &MultiResultReader{ return &MultiResultReader{
closed: true, closed: true,
err: linkErrors(err, ErrNoBytesSent), err: err,
} }
} }
@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:

View File

@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) {
config.User = "pgxinvalidusertest" config.User = "pgxinvalidusertest"
conn, err := pgconn.ConnectConfig(context.Background(), config) _, err = pgconn.ConnectConfig(context.Background(), config)
if err == nil { require.Error(t, err)
conn.Close(context.Background()) pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
t.Fatal("expected err but got none")
}
pgErr, ok := err.(*pgconn.PgError)
if !ok { if !ok {
t.Fatalf("Expected to receive a PgError, instead received: %v", err) t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
} }
if pgErr.Code != "28000" && pgErr.Code != "28P01" { if pgErr.Code != "28000" && pgErr.Code != "28P01" {
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
assert.Nil(t, psd) assert.Nil(t, psd)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) {
for multiResult.NextResult() { for multiResult.NextResult() {
} }
err = multiResult.Close() err = multiResult.Close()
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
} }
@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
} }
@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err) require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled)) assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
} }
@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err) require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled)) assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
_, err = pgConn.ExecBatch(ctx, batch).ReadAll() _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) {
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.Equal(t, "conn busy", err.Error())
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
results, err := mrr.ReadAll() results, err := mrr.ReadAll()
assert.NoError(t, err) assert.NoError(t, err)
@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx) err = pgConn.WaitForNotification(ctx)
cancel() cancel()
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.True(t, pgconn.Timeout(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag(nil), res)
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel() cancel()
assert.Equal(t, int64(0), ct.RowsAffected()) assert.Equal(t, int64(0), ct.RowsAffected())
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Error(t, err)
assert.True(t, pgConn.IsClosed()) assert.True(t, pgConn.IsClosed())
} }
@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct) assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)