From 138254da5b02b80a548f7858f01636f9a426b918 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:01:59 -0500 Subject: [PATCH] Refactor errors - Use strongly typed errors internally - SafeToRetry(error) streamlines retry logic over ErrNoBytesSent - Timeout(error) removes the need to choose between returning a context and an i/o error --- config.go | 14 ++--- errors.go | 160 +++++++++++++++++++++++++++++++++++-------------- pgconn.go | 125 +++++++++++++++++--------------------- pgconn_test.go | 41 ++++++------- 4 files changed, 197 insertions(+), 143 deletions(-) diff --git a/config.go b/config.go index cb153c77..d24d0202 100644 --- a/config.go +++ b/config.go @@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) { if strings.HasPrefix(connString, "postgres://") { err := addURLSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { err := addDSNSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { - return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} } config := &Config{ @@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) { if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.DialFunc = dialFunc } else { @@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, errors.Errorf("invalid port: %w", err) + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { var err error tlsConfigs, err = configTLS(settings) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } @@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} } return config, nil diff --git a/errors.go b/errors.go index 4f8af407..a088dcdd 100644 --- a/errors.go +++ b/errors.go @@ -2,22 +2,31 @@ package pgconn import ( "context" + "fmt" "net" + "strings" errors "golang.org/x/xerrors" ) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } -// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used -// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. -var ErrNoBytesSent = errors.New("no bytes sent to server") + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for @@ -46,44 +55,107 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -// linkedError connects two errors as if err wrapped next. -type linkedError struct { - err error - next error +type connectError struct { + config *Config + msg string + err error } -func (le *linkedError) Error() string { - return le.err.Error() -} - -func (le *linkedError) Is(target error) bool { - return errors.Is(le.err, target) -} - -func (le *linkedError) As(target interface{}) bool { - return errors.As(le.err, target) -} - -func (le *linkedError) Unwrap() error { - return le.next -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) } - return err + return sb.String() } -// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. -func linkErrors(outer, inner error) error { - if outer == nil { - return inner - } - if inner == nil { - return outer - } - return &linkedError{err: outer, next: inner} +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err } diff --git a/pgconn.go b/pgconn.go index 7d301af2..347acf80 100644 --- a/pgconn.go +++ b/pgconn.go @@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, err + return nil, &connectError{config: config, msg: "server error", err: err} } } if err != nil { - return nil, err + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } @@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { - return nil, err + return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.parameterStatuses = make(map[string]string) @@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "tls error", err: err} } } @@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() - return nil, err + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: err} } switch msg := msg.(type) { @@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("ValidateConnect: %v", err) + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, errors.New("unexpected message") + return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } @@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { } if response[0] != 'S' { - return ErrTLSRefused + return errors.New("server refused TLS connection") } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) @@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if err := pgConn.lock(); err != nil { - return linkErrors(err, ErrNoBytesSent) + return err } defer pgConn.unlock() select { case <-ctx.Done(): - return linkErrors(ctx.Err(), ErrNoBytesSent) + return &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return linkErrors(ctx.Err(), err) + return &writeError{err: err, safeToRetry: n == 0} } return nil @@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = linkErrors(ctx.Err(), err) + err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} } return msg, err } @@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return linkErrors(ctx.Err(), err) + return err } return pgConn.conn.Close() @@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } -// lock locks the connection. It panics if the connection is already locked or is closed. +// lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return ErrConnBusy // This only should be possible in case of an application bug. + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. case connStatusClosed: - return errors.New("conn closed") + return &connLockError{status: "conn closed"} case connStatusUninitialized: - return errors.New("conn uninitialized") + return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil @@ -527,13 +527,13 @@ type StatementDescription struct { // allows Prepare to also to describe statements without creating a server-side prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} @@ -562,7 +559,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) + return err } return nil @@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return linkErrors(ctx.Err(), err) + return err } switch msg.(type) { @@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: @@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - multiResult.err = linkErrors(ctx.Err(), err) + multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } @@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.concludeCommand(nil, err) result.closed = true return result } @@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.closed = true pgConn.unlock() return result @@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - result.concludeCommand(nil, linkErrors(ctx.Err(), err)) + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } select { case <-ctx.Done(): pgConn.unlock() - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.hardClose() pgConn.unlock() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read until copy in response or error. @@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } } @@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } // Read results @@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = err mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err @@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { } rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.err = err rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true @@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 64628262..3fbdf8df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) { config.User = "pgxinvalidusertest" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) @@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { assert.Nil(t, psd) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) { for multiResult.NextResult() { } err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) { _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) @@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.True(t, pgconn.Timeout(err)) ensureConnValid(t, pgConn) } @@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) @@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.True(t, pgConn.IsClosed()) } @@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn)