mirror of https://github.com/jackc/pgx.git
Add ensureReadyForQuery to pgconn
parent
19a8df16b6
commit
b3cc9aa8a7
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
"github.com/jackc/pgx/pgconn"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
require.Nil(t, conn.Close(ctx))
|
require.Nil(t, conn.Close(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do a simple query to ensure the connection is still usable
|
||||||
|
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, 3, len(result.Rows))
|
||||||
|
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||||
|
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||||
|
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||||
|
}
|
||||||
|
|
102
pgconn/pgconn.go
102
pgconn/pgconn.go
|
@ -562,23 +562,28 @@ func (rr *PgResultReader) close() {
|
||||||
|
|
||||||
// Flush sends the enqueued execs to the server.
|
// Flush sends the enqueued execs to the server.
|
||||||
func (pgConn *PgConn) Flush(ctx context.Context) error {
|
func (pgConn *PgConn) Flush(ctx context.Context) error {
|
||||||
defer pgConn.resetBatch()
|
|
||||||
|
|
||||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanup()
|
err := pgConn.flush()
|
||||||
|
cleanup()
|
||||||
|
return preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush sends the enqueued execs to the server without handling a context.
|
||||||
|
func (pgConn *PgConn) flush() error {
|
||||||
n, err := pgConn.conn.Write(pgConn.batchBuf)
|
n, err := pgConn.conn.Write(pgConn.batchBuf)
|
||||||
if err != nil {
|
if err != nil && n > 0 {
|
||||||
if n > 0 {
|
// Close connection because cannot recover from partially sent message.
|
||||||
// Close connection because cannot recover from partially sent message.
|
pgConn.conn.Close()
|
||||||
pgConn.conn.Close()
|
pgConn.closed = true
|
||||||
pgConn.closed = true
|
|
||||||
}
|
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.pendingReadyForQueryCount += pgConn.batchCount
|
if err == nil {
|
||||||
return nil
|
pgConn.pendingReadyForQueryCount += pgConn.batchCount
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.resetBatch()
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
|
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
|
||||||
|
@ -646,13 +651,11 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContext()
|
defer cleanupContext()
|
||||||
|
|
||||||
for pgConn.pendingReadyForQueryCount > 0 {
|
err := pgConn.ensureReadyForQuery()
|
||||||
_, err := pgConn.ReceiveMessage()
|
if err != nil {
|
||||||
if err != nil {
|
preferContextOverNetTimeoutError(ctx, err)
|
||||||
preferContextOverNetTimeoutError(ctx, err)
|
pgConn.Close(context.Background())
|
||||||
pgConn.Close(context.Background())
|
return false
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := pgConn.Exec(
|
result, err := pgConn.Exec(
|
||||||
|
@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureReadyForQuery reads until pendingReadyForQueryCount == 0.
|
||||||
|
func (pgConn *PgConn) ensureReadyForQuery() error {
|
||||||
|
for pgConn.pendingReadyForQueryCount > 0 {
|
||||||
|
_, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) resetBatch() {
|
func (pgConn *PgConn) resetBatch() {
|
||||||
pgConn.batchCount = 0
|
pgConn.batchCount = 0
|
||||||
if len(pgConn.batchBuf) > batchBufferSize {
|
if len(pgConn.batchBuf) > batchBufferSize {
|
||||||
|
@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
|
||||||
if pgConn.batchCount != 0 {
|
if pgConn.batchCount != 0 {
|
||||||
return nil, errors.New("unflushed previous sends")
|
return nil, errors.New("unflushed previous sends")
|
||||||
}
|
}
|
||||||
if pgConn.pendingReadyForQueryCount != 0 {
|
|
||||||
return nil, errors.New("unread previous results")
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
err := pgConn.ensureReadyForQuery()
|
||||||
|
if err != nil {
|
||||||
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.SendExec(sql)
|
pgConn.SendExec(sql)
|
||||||
err := pgConn.Flush(ctx)
|
err = pgConn.flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return pgConn.bufferLastResult(ctx)
|
return pgConn.bufferLastResult(ctx)
|
||||||
|
@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
||||||
if pgConn.batchCount != 0 {
|
if pgConn.batchCount != 0 {
|
||||||
return nil, errors.New("unflushed previous sends")
|
return nil, errors.New("unflushed previous sends")
|
||||||
}
|
}
|
||||||
if pgConn.pendingReadyForQueryCount != 0 {
|
|
||||||
return nil, errors.New("unread previous results")
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
err := pgConn.ensureReadyForQuery()
|
||||||
|
if err != nil {
|
||||||
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
|
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
|
||||||
err := pgConn.Flush(ctx)
|
err = pgConn.flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||||
if pgConn.batchCount != 0 {
|
if pgConn.batchCount != 0 {
|
||||||
return nil, errors.New("unflushed previous sends")
|
return nil, errors.New("unflushed previous sends")
|
||||||
}
|
}
|
||||||
if pgConn.pendingReadyForQueryCount != 0 {
|
|
||||||
return nil, errors.New("unread previous results")
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
err := pgConn.ensureReadyForQuery()
|
||||||
|
if err != nil {
|
||||||
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
|
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
|
||||||
err := pgConn.Flush(ctx)
|
err = pgConn.flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
||||||
if pgConn.batchCount != 0 {
|
if pgConn.batchCount != 0 {
|
||||||
return nil, errors.New("unflushed previous sends")
|
return nil, errors.New("unflushed previous sends")
|
||||||
}
|
}
|
||||||
if pgConn.pendingReadyForQueryCount != 0 {
|
|
||||||
return nil, errors.New("unread previous results")
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContext()
|
defer cleanup()
|
||||||
|
|
||||||
|
err := pgConn.ensureReadyForQuery()
|
||||||
|
if err != nil {
|
||||||
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf)
|
pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf)
|
||||||
pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)
|
pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)
|
||||||
pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf)
|
pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf)
|
||||||
pgConn.batchCount += 1
|
pgConn.batchCount += 1
|
||||||
err := pgConn.Flush(context.Background())
|
err = pgConn.flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, 1, len(result.Rows))
|
assert.Equal(t, 1, len(result.Rows))
|
||||||
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
|
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecEmpty(t *testing.T) {
|
func TestConnExecEmpty(t *testing.T) {
|
||||||
|
@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Nil(t, result.CommandTag)
|
assert.Nil(t, result.CommandTag)
|
||||||
assert.Equal(t, 0, len(result.Rows))
|
assert.Equal(t, 0, len(result.Rows))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecMultipleQueries(t *testing.T) {
|
func TestConnExecMultipleQueries(t *testing.T) {
|
||||||
|
@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, 1, len(result.Rows))
|
assert.Equal(t, 1, len(result.Rows))
|
||||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecMultipleQueriesError(t *testing.T) {
|
func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||||
|
@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecContextCanceled(t *testing.T) {
|
func TestConnExecContextCanceled(t *testing.T) {
|
||||||
|
@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) {
|
||||||
assert.Equal(t, context.DeadlineExceeded, err)
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
|
||||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecParams(t *testing.T) {
|
func TestConnExecParams(t *testing.T) {
|
||||||
|
@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, 1, len(result.Rows))
|
assert.Equal(t, 1, len(result.Rows))
|
||||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecParamsCanceled(t *testing.T) {
|
func TestConnExecParamsCanceled(t *testing.T) {
|
||||||
|
@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
||||||
assert.Equal(t, context.DeadlineExceeded, err)
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
|
||||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecPrepared(t *testing.T) {
|
func TestConnExecPrepared(t *testing.T) {
|
||||||
|
@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, 1, len(result.Rows))
|
assert.Equal(t, 1, len(result.Rows))
|
||||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnExecPreparedCanceled(t *testing.T) {
|
func TestConnExecPreparedCanceled(t *testing.T) {
|
||||||
|
@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
||||||
assert.Equal(t, context.DeadlineExceeded, err)
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
|
||||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnBatchedQueries(t *testing.T) {
|
func TestConnBatchedQueries(t *testing.T) {
|
||||||
|
@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) {
|
||||||
|
|
||||||
// Done
|
// Done
|
||||||
require.False(t, pgConn.NextResult(context.Background()))
|
require.False(t, pgConn.NextResult(context.Background()))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnRecoverFromTimeout(t *testing.T) {
|
func TestConnRecoverFromTimeout(t *testing.T) {
|
||||||
|
@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) {
|
||||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnCancelQuery(t *testing.T) {
|
func TestConnCancelQuery(t *testing.T) {
|
||||||
|
@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
t.Errorf("expected pgconn.PgError got %v", err)
|
t.Errorf("expected pgconn.PgError got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
require.False(t, pgConn.NextResult(context.Background()))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommandTag(t *testing.T) {
|
func TestCommandTag(t *testing.T) {
|
||||||
|
@ -573,4 +599,6 @@ begin
|
||||||
end$$;`)
|
end$$;`)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
assert.Equal(t, "hello, world", msg)
|
assert.Equal(t, "hello, world", msg)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue