mirror of https://github.com/jackc/pgx.git
Tag errors if no bytes sent to server
parent
0f8e1d30e2
commit
7e0022ef6b
|
@ -15,6 +15,10 @@ var ErrTLSRefused = errors.New("server refused TLS connection")
|
|||
// action is attempted.
|
||||
var ErrConnBusy = errors.New("conn is busy")
|
||||
|
||||
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
|
||||
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
|
||||
var ErrNoBytesSent = errors.New("no bytes sent to server")
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
|
|
58
pgconn.go
58
pgconn.go
|
@ -444,13 +444,13 @@ type PreparedStatementDescription struct {
|
|||
// Prepare creates a prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, err
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
default:
|
||||
}
|
||||
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
|
||||
|
@ -461,9 +461,12 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
|
||||
buf = (&pgproto3.Sync{}).Encode(buf)
|
||||
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
}
|
||||
|
||||
|
@ -601,7 +604,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: err,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -614,7 +617,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
|
@ -624,11 +627,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
buf := pgConn.wbuf
|
||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
||||
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
pgConn.doneChanToDeadline.cleanup()
|
||||
multiResult.closed = true
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
multiResult.err = linkErrors(ctx.Err(), err)
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
|
@ -666,7 +672,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
|
||||
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
|
||||
|
||||
pgConn.execExtendedSuffix(buf, result)
|
||||
pgConn.execExtendedSuffix(ctx, buf, result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
@ -692,7 +698,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||
buf := pgConn.wbuf
|
||||
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
|
||||
|
||||
pgConn.execExtendedSuffix(buf, result)
|
||||
pgConn.execExtendedSuffix(ctx, buf, result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
@ -701,7 +707,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||
if err := pgConn.lock(); err != nil {
|
||||
return &ResultReader{
|
||||
closed: true,
|
||||
err: err,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -720,7 +726,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand(nil, ctx.Err())
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
|
@ -731,15 +737,18 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||
return result
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
|
||||
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) {
|
||||
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
|
||||
buf = (&pgproto3.Execute{}).Encode(buf)
|
||||
buf = (&pgproto3.Sync{}).Encode(buf)
|
||||
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
result.concludeCommand(nil, err)
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
|
||||
pgConn.doneChanToDeadline.cleanup()
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
|
@ -749,13 +758,13 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
|
|||
// CopyTo executes the copy command sql and copies the results to w.
|
||||
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, err
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return nil, ctx.Err()
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
default:
|
||||
}
|
||||
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
|
||||
|
@ -765,11 +774,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
buf := pgConn.wbuf
|
||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
||||
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
pgConn.unlock()
|
||||
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
}
|
||||
|
||||
|
@ -808,13 +819,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
// could still block.
|
||||
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, err
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
default:
|
||||
}
|
||||
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
|
||||
|
@ -824,9 +835,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
buf := pgConn.wbuf
|
||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
||||
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
}
|
||||
|
||||
|
@ -1191,7 +1205,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: err,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1204,7 +1218,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
|
|
|
@ -264,9 +264,10 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
||||
require.Nil(t, psd)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.Nil(t, psd)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -386,8 +387,9 @@ func TestConnExecContextPrecanceled(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -492,7 +494,8 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
|
|||
cancel()
|
||||
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, context.Canceled, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -620,7 +623,8 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
|
|||
cancel()
|
||||
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, context.Canceled, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -677,7 +681,8 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
|
|||
cancel()
|
||||
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
@ -750,7 +755,8 @@ func TestConnLocking(t *testing.T) {
|
|||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, pgconn.ErrConnBusy, err)
|
||||
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
|
||||
results, err = mrr.ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
@ -1036,7 +1042,8 @@ func TestConnCopyToPrecanceled(t *testing.T) {
|
|||
cancel()
|
||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
|
@ -1143,7 +1150,8 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
|
|||
cancel()
|
||||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, context.Canceled, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), ct)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
|
|
Loading…
Reference in New Issue