Tag errors if no bytes sent to server

query-exec-mode
Jack Christensen 2019-04-20 16:48:24 -05:00
parent 0f8e1d30e2
commit 7e0022ef6b
3 changed files with 59 additions and 33 deletions

View File

@ -15,6 +15,10 @@ var ErrTLSRefused = errors.New("server refused TLS connection")
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
var ErrNoBytesSent = errors.New("no bytes sent to server")
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.

View File

@ -444,13 +444,13 @@ type PreparedStatementDescription struct {
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, err
return nil, linkErrors(err, ErrNoBytesSent)
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
@ -461,9 +461,12 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
_, err := pgConn.conn.Write(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
}
@ -601,7 +604,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
err: linkErrors(err, ErrNoBytesSent),
}
}
@ -614,7 +617,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
pgConn.unlock()
return multiResult
default:
@ -624,11 +627,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
_, err := pgConn.conn.Write(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
multiResult.err = linkErrors(ctx.Err(), err)
pgConn.unlock()
return multiResult
@ -666,7 +672,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(buf, result)
pgConn.execExtendedSuffix(ctx, buf, result)
return result
}
@ -692,7 +698,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(buf, result)
pgConn.execExtendedSuffix(ctx, buf, result)
return result
}
@ -701,7 +707,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
if err := pgConn.lock(); err != nil {
return &ResultReader{
closed: true,
err: err,
err: linkErrors(err, ErrNoBytesSent),
}
}
@ -720,7 +726,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select {
case <-ctx.Done():
result.concludeCommand(nil, ctx.Err())
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
result.closed = true
pgConn.unlock()
return result
@ -731,15 +737,18 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
}
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
_, err := pgConn.conn.Write(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
result.concludeCommand(nil, err)
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
pgConn.doneChanToDeadline.cleanup()
result.closed = true
pgConn.unlock()
@ -749,13 +758,13 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, err
return nil, linkErrors(err, ErrNoBytesSent)
}
select {
case <-ctx.Done():
pgConn.unlock()
return nil, ctx.Err()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
@ -765,11 +774,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
_, err := pgConn.conn.Write(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
pgConn.unlock()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
}
@ -808,13 +819,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, err
return nil, linkErrors(err, ErrNoBytesSent)
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
@ -824,9 +835,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
_, err := pgConn.conn.Write(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
}
@ -1191,7 +1205,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
err: linkErrors(err, ErrNoBytesSent),
}
}
@ -1204,7 +1218,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
pgConn.unlock()
return multiResult
default:

View File

@ -264,9 +264,10 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
require.Nil(t, psd)
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.Nil(t, psd)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
ensureConnValid(t, pgConn)
}
@ -386,8 +387,9 @@ func TestConnExecContextPrecanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
ensureConnValid(t, pgConn)
}
@ -492,7 +494,8 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
cancel()
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err)
require.Equal(t, context.Canceled, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
ensureConnValid(t, pgConn)
}
@ -620,7 +623,8 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
cancel()
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err)
require.Equal(t, context.Canceled, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
ensureConnValid(t, pgConn)
}
@ -677,7 +681,8 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
cancel()
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
ensureConnValid(t, pgConn)
}
@ -750,7 +755,8 @@ func TestConnLocking(t *testing.T) {
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.Equal(t, pgconn.ErrConnBusy, err)
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
results, err = mrr.ReadAll()
assert.NoError(t, err)
@ -1036,7 +1042,8 @@ func TestConnCopyToPrecanceled(t *testing.T) {
cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn)
@ -1143,7 +1150,8 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
cancel()
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err)
require.Equal(t, context.Canceled, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn)