Refactor errors

- Use strongly typed errors internally
- SafeToRetry(error) streamlines retry logic over ErrNoBytesSent
- Timeout(error) removes the need to choose between returning a context
  and an i/o error
query-exec-mode
Jack Christensen 2019-08-27 18:01:59 -05:00
parent e6cf51b304
commit 138254da5b
4 changed files with 197 additions and 143 deletions

View File

@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) {
if strings.HasPrefix(connString, "postgres://") {
err := addURLSettings(settings, connString)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
err := addDSNSettings(settings, connString)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err)
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{
@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) {
if connectTimeout, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.DialFunc = dialFunc
} else {
@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
port, err := parsePort(portStr)
if err != nil {
return nil, errors.Errorf("invalid port: %w", err)
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
var err error
tlsConfigs, err = configTLS(settings)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
if settings["target_session_attrs"] == "read-write" {
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" {
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
}
return config, nil

160
errors.go
View File

@ -2,22 +2,31 @@ package pgconn
import (
"context"
"fmt"
"net"
"strings"
errors "golang.org/x/xerrors"
)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
var ErrNoBytesSent = errors.New("no bytes sent to server")
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// linkedError connects two errors as if err wrapped next.
type linkedError struct {
err error
next error
type connectError struct {
config *Config
msg string
err error
}
func (le *linkedError) Error() string {
return le.err.Error()
}
func (le *linkedError) Is(target error) bool {
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return err
return sb.String()
}
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
func linkErrors(outer, inner error) error {
if outer == nil {
return inner
}
if inner == nil {
return outer
}
return &linkedError{err: outer, next: inner}
func (e *connectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}

125
pgconn.go
View File

@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
if err == nil {
break
} else if err, ok := err.(*PgError); ok {
return nil, err
return nil, &connectError{config: config, msg: "server error", err: err}
}
}
if err != nil {
return nil, err
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err)
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
}
}
@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil {
return nil, err
return nil, &connectError{config: config, msg: "dial error", err: err}
}
pgConn.parameterStatuses = make(map[string]string)
@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if fallbackConfig.TLSConfig != nil {
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "tls error", err: err}
}
}
@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
}
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.conn.Close()
return nil, err
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
}
switch msg := msg.(type) {
@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "failed handle authentication message", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err := config.ValidateConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, errors.Errorf("ValidateConnect: %v", err)
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, errors.New("unexpected message")
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
}
}
}
@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
}
if response[0] != 'S' {
return ErrTLSRefused
return errors.New("server refused TLS connection")
}
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return linkErrors(err, ErrNoBytesSent)
return err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return linkErrors(ctx.Err(), ErrNoBytesSent)
return &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return linkErrors(ctx.Err(), err)
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
msg, err := pgConn.receiveMessage()
if err != nil {
err = linkErrors(ctx.Err(), err)
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
}
return msg, err
}
@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
_, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF {
return linkErrors(ctx.Err(), err)
return err
}
return pgConn.conn.Close()
@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool {
return pgConn.status < connStatusIdle
}
// lock locks the connection. It panics if the connection is already locked or is closed.
// lock locks the connection.
func (pgConn *PgConn) lock() error {
switch pgConn.status {
case connStatusBusy:
return ErrConnBusy // This only should be possible in case of an application bug.
return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
case connStatusClosed:
return errors.New("conn closed")
return &connLockError{status: "conn closed"}
case connStatusUninitialized:
return errors.New("conn uninitialized")
return &connLockError{status: "conn uninitialized"}
}
pgConn.status = connStatusBusy
return nil
@ -527,13 +527,13 @@ type StatementDescription struct {
// allows Prepare to also to describe statements without creating a server-side prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
}
psd := &StatementDescription{Name: name, SQL: sql}
@ -562,7 +559,7 @@ readloop:
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
return err
}
return nil
@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for {
msg, err := pgConn.receiveMessage()
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
switch msg.(type) {
@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
err: err,
}
}
@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose()
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
multiResult.err = linkErrors(ctx.Err(), err)
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
pgConn.unlock()
return multiResult
}
@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
result := &pgConn.resultReader
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent))
result.concludeCommand(nil, err)
result.closed = true
return result
}
@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select {
case <-ctx.Done():
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true
pgConn.unlock()
return result
@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
select {
case <-ctx.Done():
pgConn.unlock()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil {
pgConn.hardClose()
pgConn.unlock()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &writeError{err: err, safeToRetry: n == 0}
}
// Read results
@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &writeError{err: err, safeToRetry: n == 0}
}
// Read until copy in response or error.
@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
}
@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
// Read results
@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
if err != nil {
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.err = err
mrr.closed = true
mrr.pgConn.hardClose()
return nil, mrr.err
@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
}
rr.commandTag = commandTag
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.err = err
rr.fieldDescriptions = nil
rr.rowValues = nil
rr.commandConcluded = true
@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
err: err,
}
}
@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:

View File

@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) {
config.User = "pgxinvalidusertest"
conn, err := pgconn.ConnectConfig(context.Background(), config)
if err == nil {
conn.Close(context.Background())
t.Fatal("expected err but got none")
}
pgErr, ok := err.(*pgconn.PgError)
_, err = pgconn.ConnectConfig(context.Background(), config)
require.Error(t, err)
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
if !ok {
t.Fatalf("Expected to receive a PgError, instead received: %v", err)
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
}
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
assert.Nil(t, psd)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) {
for multiResult.NextResult() {
}
err = multiResult.Close()
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
}
@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
}
@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn)
}
@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
}
@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn)
}
@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) {
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.Equal(t, "conn busy", err.Error())
assert.True(t, pgconn.SafeToRetry(err))
results, err := mrr.ReadAll()
assert.NoError(t, err)
@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx)
cancel()
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.True(t, pgconn.Timeout(err))
ensureConnValid(t, pgConn)
}
@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res)
assert.True(t, pgConn.IsClosed())
@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn)
@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel()
assert.Equal(t, int64(0), ct.RowsAffected())
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Error(t, err)
assert.True(t, pgConn.IsClosed())
}
@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn)