Link context errors and underlying conn errors

Using golang.org/x/xerrors type errors both errors can be exposed.
query-exec-mode
Jack Christensen 2019-04-20 15:53:30 -05:00
parent f3b5f6b275
commit 0f8e1d30e2
3 changed files with 105 additions and 64 deletions

85
errors.go Normal file
View File

@ -0,0 +1,85 @@
package pgconn
import (
"context"
"net"
errors "golang.org/x/xerrors"
)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// linkedError connects two errors as if err wrapped next.
type linkedError struct {
err error
next error
}
func (le *linkedError) Error() string {
return le.err.Error()
}
func (le *linkedError) Is(target error) bool {
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
func linkErrors(outer, inner error) error {
if outer == nil {
return inner
}
if inner == nil {
return outer
}
return &linkedError{err: outer, next: inner}
}

View File

@ -26,33 +26,6 @@ const (
connStatusBusy
)
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification.
type Notice PgError
@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
return linkErrors(ctx.Err(), err)
}
_, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF {
return preferContextOverNetTimeoutError(ctx, err)
return linkErrors(ctx.Err(), err)
}
return pgConn.conn.Close()
@ -469,15 +434,6 @@ func (ct CommandTag) String() string {
return string(ct)
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
type PreparedStatementDescription struct {
Name string
SQL string
@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
psd := &PreparedStatementDescription{Name: name, SQL: sql}
@ -520,7 +476,7 @@ readloop:
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
switch msg := msg.(type) {
@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
return linkErrors(ctx.Err(), err)
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err))
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
}
return nil
@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
return linkErrors(ctx.Err(), err)
}
switch msg.(type) {
@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose()
pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
multiResult.err = linkErrors(ctx.Err(), err)
pgConn.unlock()
return multiResult
}
@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
pgConn.hardClose()
pgConn.unlock()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
// Read results
@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
switch msg := msg.(type) {
@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
// Read until copy in response or error.
@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
switch msg := msg.(type) {
@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
}
@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
switch msg := msg.(type) {
@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
// Read results
@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
return nil, linkErrors(ctx.Err(), err)
}
switch msg := msg.(type) {

View File

@ -18,7 +18,7 @@ import (
"time"
"github.com/jackc/pgconn"
"github.com/pkg/errors"
errors "golang.org/x/xerrors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx)
cancel()
require.Equal(t, context.DeadlineExceeded, err)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
ensureConnValid(t, pgConn)
}
@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Equal(t, pgconn.CommandTag(nil), res)
assert.False(t, pgConn.IsAlive())
@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel()
assert.Equal(t, int64(0), ct.RowsAffected())
require.Equal(t, context.DeadlineExceeded, err)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.False(t, pgConn.IsAlive())
}