diff --git a/errors.go b/errors.go index e42dae16..4f8af407 100644 --- a/errors.go +++ b/errors.go @@ -15,6 +15,10 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // action is attempted. var ErrConnBusy = errors.New("conn is busy") +// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used +// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. +var ErrNoBytesSent = errors.New("no bytes sent to server") + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. diff --git a/pgconn.go b/pgconn.go index 2911211c..a4402a7d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -444,13 +444,13 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -461,9 +461,12 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -601,7 +604,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -614,7 +617,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: @@ -624,11 +627,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult @@ -666,7 +672,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -692,7 +698,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -701,7 +707,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if err := pgConn.lock(); err != nil { return &ResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -720,7 +726,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.closed = true pgConn.unlock() return result @@ -731,15 +737,18 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand(nil, err) + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + result.concludeCommand(nil, linkErrors(ctx.Err(), err)) pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() @@ -749,13 +758,13 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } select { case <-ctx.Done(): pgConn.unlock() - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -765,11 +774,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.unlock() - + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -808,13 +819,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -824,9 +835,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -1191,7 +1205,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -1204,7 +1218,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 30e6a425..b7cb4036 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -264,9 +264,10 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) - require.Nil(t, psd) - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -386,8 +387,9 @@ func TestConnExecContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -492,7 +494,8 @@ func TestConnExecParamsPrecanceled(t *testing.T) { cancel() result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -620,7 +623,8 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { cancel() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -677,7 +681,8 @@ func TestConnExecBatchPrecanceled(t *testing.T) { cancel() _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -750,7 +755,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.Equal(t, pgconn.ErrConnBusy, err) + assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) results, err = mrr.ReadAll() assert.NoError(t, err) @@ -1036,7 +1042,8 @@ func TestConnCopyToPrecanceled(t *testing.T) { cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1143,7 +1150,8 @@ func TestConnCopyFromPrecanceled(t *testing.T) { cancel() ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn)