mirror of https://github.com/jackc/pgx.git
Refactor errors
- Use strongly typed errors internally - SafeToRetry(error) streamlines retry logic over ErrNoBytesSent - Timeout(error) removes the need to choose between returning a context and an i/o errorquery-exec-mode
parent
e6cf51b304
commit
138254da5b
14
config.go
14
config.go
|
@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
if strings.HasPrefix(connString, "postgres://") {
|
||||
err := addURLSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
err := addDSNSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err)
|
||||
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
|
@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.DialFunc = dialFunc
|
||||
} else {
|
||||
|
@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid port: %w", err)
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
var err error
|
||||
tlsConfigs, err = configTLS(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
if settings["target_session_attrs"] == "read-write" {
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||
} else if settings["target_session_attrs"] != "any" {
|
||||
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
|
|
160
errors.go
160
errors.go
|
@ -2,22 +2,31 @@ package pgconn
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||
// PostgreSQL server refuses to use TLS
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||
return e.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
|
||||
// action is attempted.
|
||||
var ErrConnBusy = errors.New("conn is busy")
|
||||
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
|
||||
// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
|
||||
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
|
||||
var ErrNoBytesSent = errors.New("no bytes sent to server")
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
|
@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
|
|||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// linkedError connects two errors as if err wrapped next.
|
||||
type linkedError struct {
|
||||
err error
|
||||
next error
|
||||
type connectError struct {
|
||||
config *Config
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (le *linkedError) Error() string {
|
||||
return le.err.Error()
|
||||
}
|
||||
|
||||
func (le *linkedError) Is(target error) bool {
|
||||
return errors.Is(le.err, target)
|
||||
}
|
||||
|
||||
func (le *linkedError) As(target interface{}) bool {
|
||||
return errors.As(le.err, target)
|
||||
}
|
||||
|
||||
func (le *linkedError) Unwrap() error {
|
||||
return le.next
|
||||
}
|
||||
|
||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||
// true. Otherwise returns err.
|
||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
func (e *connectError) Error() string {
|
||||
sb := &strings.Builder{}
|
||||
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||
if e.err != nil {
|
||||
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||
}
|
||||
return err
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
|
||||
func linkErrors(outer, inner error) error {
|
||||
if outer == nil {
|
||||
return inner
|
||||
}
|
||||
if inner == nil {
|
||||
return outer
|
||||
}
|
||||
return &linkedError{err: outer, next: inner}
|
||||
func (e *connectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type connLockError struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *connLockError) SafeToRetry() bool {
|
||||
return true // a lock failure by definition happens before the connection is used.
|
||||
}
|
||||
|
||||
func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
type parseConfigError struct {
|
||||
connString string
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Error() string {
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
|
||||
}
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *pgconnError) Error() string {
|
||||
if e.msg == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
if e.err == nil {
|
||||
return e.msg
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *pgconnError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *pgconnError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type contextAlreadyDoneError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Error() string {
|
||||
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type writeError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *writeError) Error() string {
|
||||
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *writeError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *writeError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
|
125
pgconn.go
125
pgconn.go
|
@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||
if err == nil {
|
||||
break
|
||||
} else if err, ok := err.(*PgError); ok {
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "server error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
|
||||
}
|
||||
|
||||
if config.AfterConnect != nil {
|
||||
err := config.AfterConnect(ctx, pgConn)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.Errorf("AfterConnect: %v", err)
|
||||
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
|
@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
if fallbackConfig.TLSConfig != nil {
|
||||
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
|
||||
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
if err, ok := err.(*PgError); ok {
|
||||
return nil, err
|
||||
}
|
||||
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
case *pgproto3.Authentication:
|
||||
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "failed handle authentication message", err: err}
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
pgConn.status = connStatusIdle
|
||||
|
@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
err := config.ValidateConnect(ctx, pgConn)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.Errorf("ValidateConnect: %v", err)
|
||||
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
|
||||
}
|
||||
}
|
||||
return pgConn, nil
|
||||
|
@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
return nil, ErrorResponseToPgError(msg)
|
||||
default:
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.New("unexpected message")
|
||||
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
|||
}
|
||||
|
||||
if response[0] != 'S' {
|
||||
return ErrTLSRefused
|
||||
return errors.New("server refused TLS connection")
|
||||
}
|
||||
|
||||
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
|
||||
|
@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
|||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return linkErrors(err, ErrNoBytesSent)
|
||||
return err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
|||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
|||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
|
|||
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
err = linkErrors(ctx.Err(), err)
|
||||
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
|
||||
}
|
||||
return msg, err
|
||||
}
|
||||
|
@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
|
|||
|
||||
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pgConn.conn.Read(make([]byte, 1))
|
||||
if err != io.EOF {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
return pgConn.conn.Close()
|
||||
|
@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool {
|
|||
return pgConn.status < connStatusIdle
|
||||
}
|
||||
|
||||
// lock locks the connection. It panics if the connection is already locked or is closed.
|
||||
// lock locks the connection.
|
||||
func (pgConn *PgConn) lock() error {
|
||||
switch pgConn.status {
|
||||
case connStatusBusy:
|
||||
return ErrConnBusy // This only should be possible in case of an application bug.
|
||||
return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
|
||||
case connStatusClosed:
|
||||
return errors.New("conn closed")
|
||||
return &connLockError{status: "conn closed"}
|
||||
case connStatusUninitialized:
|
||||
return errors.New("conn uninitialized")
|
||||
return &connLockError{status: "conn uninitialized"}
|
||||
}
|
||||
pgConn.status = connStatusBusy
|
||||
return nil
|
||||
|
@ -527,13 +527,13 @@ type StatementDescription struct {
|
|||
// allows Prepare to also to describe statements without creating a server-side prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
psd := &StatementDescription{Name: name, SQL: sql}
|
||||
|
@ -562,7 +559,7 @@ readloop:
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
|||
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cancelConn.Read(buf)
|
||||
if err != io.EOF {
|
||||
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||
for {
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
|
@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
|
@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
pgConn.hardClose()
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.closed = true
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
multiResult.err = linkErrors(ctx.Err(), err)
|
||||
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||
result := &pgConn.resultReader
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent))
|
||||
result.concludeCommand(nil, err)
|
||||
result.closed = true
|
||||
return result
|
||||
}
|
||||
|
@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
|
||||
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
|
@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
|
|||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
|
||||
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
|
@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
|
|||
// CopyTo executes the copy command sql and copies the results to w.
|
||||
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
pgConn.unlock()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
// Read results
|
||||
|
@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
// could still block.
|
||||
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
// Read until copy in response or error.
|
||||
|
@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read results
|
||||
|
@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||
|
||||
if err != nil {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
|
||||
mrr.err = err
|
||||
mrr.closed = true
|
||||
mrr.pgConn.hardClose()
|
||||
return nil, mrr.err
|
||||
|
@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
|||
}
|
||||
|
||||
rr.commandTag = commandTag
|
||||
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
|
||||
rr.err = err
|
||||
rr.fieldDescriptions = nil
|
||||
rr.rowValues = nil
|
||||
rr.commandConcluded = true
|
||||
|
@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
|
|
|
@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) {
|
|||
|
||||
config.User = "pgxinvalidusertest"
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
if err == nil {
|
||||
conn.Close(context.Background())
|
||||
t.Fatal("expected err but got none")
|
||||
}
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
_, err = pgconn.ConnectConfig(context.Background(), config)
|
||||
require.Error(t, err)
|
||||
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
|
||||
if !ok {
|
||||
t.Fatalf("Expected to receive a PgError, instead received: %v", err)
|
||||
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
|
||||
}
|
||||
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
|
||||
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
|
||||
|
@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
|
|||
assert.Nil(t, psd)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) {
|
|||
for multiResult.NextResult() {
|
||||
}
|
||||
err = multiResult.Close()
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
|
@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
|
|||
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
|||
assert.Equal(t, 0, rowCount)
|
||||
commandTag, err := result.Close()
|
||||
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
|
|||
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(result.Err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
|||
assert.Equal(t, 0, rowCount)
|
||||
commandTag, err := result.Close()
|
||||
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
|
@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
|
|||
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(result.Err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
|
|||
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) {
|
|||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.Equal(t, "conn busy", err.Error())
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
results, err := mrr.ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
err = pgConn.WaitForNotification(ctx)
|
||||
cancel()
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
|
@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
|
|||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
|
@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
|
|||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
cancel()
|
||||
assert.Equal(t, int64(0), ct.RowsAffected())
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
|
|||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), ct)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
|
|
Loading…
Reference in New Issue