mirror of https://github.com/jackc/pgx.git
Add ensureReadyForQuery to pgconn
parent
19a8df16b6
commit
b3cc9aa8a7
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
|||
defer cancel()
|
||||
require.Nil(t, conn.Close(ctx))
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil)
|
||||
cancel()
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 3, len(result.Rows))
|
||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||
}
|
||||
|
|
|
@ -562,23 +562,28 @@ func (rr *PgResultReader) close() {
|
|||
|
||||
// Flush sends the enqueued execs to the server.
|
||||
func (pgConn *PgConn) Flush(ctx context.Context) error {
|
||||
defer pgConn.resetBatch()
|
||||
|
||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanup()
|
||||
err := pgConn.flush()
|
||||
cleanup()
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
// flush sends the enqueued execs to the server without handling a context.
|
||||
func (pgConn *PgConn) flush() error {
|
||||
n, err := pgConn.conn.Write(pgConn.batchBuf)
|
||||
if err != nil {
|
||||
if n > 0 {
|
||||
if err != nil && n > 0 {
|
||||
// Close connection because cannot recover from partially sent message.
|
||||
pgConn.conn.Close()
|
||||
pgConn.closed = true
|
||||
}
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
|
||||
if err == nil {
|
||||
pgConn.pendingReadyForQueryCount += pgConn.batchCount
|
||||
}
|
||||
|
||||
pgConn.pendingReadyForQueryCount += pgConn.batchCount
|
||||
return nil
|
||||
pgConn.resetBatch()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
|
||||
|
@ -646,14 +651,12 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
|
|||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContext()
|
||||
|
||||
for pgConn.pendingReadyForQueryCount > 0 {
|
||||
_, err := pgConn.ReceiveMessage()
|
||||
err := pgConn.ensureReadyForQuery()
|
||||
if err != nil {
|
||||
preferContextOverNetTimeoutError(ctx, err)
|
||||
pgConn.Close(context.Background())
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
result, err := pgConn.Exec(
|
||||
context.Background(), // do not use ctx again because deadline goroutine already started above
|
||||
|
@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// ensureReadyForQuery reads until pendingReadyForQueryCount == 0.
|
||||
func (pgConn *PgConn) ensureReadyForQuery() error {
|
||||
for pgConn.pendingReadyForQueryCount > 0 {
|
||||
_, err := pgConn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) resetBatch() {
|
||||
pgConn.batchCount = 0
|
||||
if len(pgConn.batchBuf) > batchBufferSize {
|
||||
|
@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
|
|||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
|
||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanup()
|
||||
|
||||
err := pgConn.ensureReadyForQuery()
|
||||
if err != nil {
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
pgConn.SendExec(sql)
|
||||
err := pgConn.Flush(ctx)
|
||||
err = pgConn.flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
return pgConn.bufferLastResult(ctx)
|
||||
|
@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
|
||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanup()
|
||||
|
||||
err := pgConn.ensureReadyForQuery()
|
||||
if err != nil {
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
|
||||
err := pgConn.Flush(ctx)
|
||||
err = pgConn.flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
|
||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanup()
|
||||
|
||||
err := pgConn.ensureReadyForQuery()
|
||||
if err != nil {
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
|
||||
err := pgConn.Flush(ctx)
|
||||
err = pgConn.flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
}
|
||||
|
||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContext()
|
||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanup()
|
||||
|
||||
err := pgConn.ensureReadyForQuery()
|
||||
if err != nil {
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf)
|
||||
pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)
|
||||
pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf)
|
||||
pgConn.batchCount += 1
|
||||
err := pgConn.Flush(context.Background())
|
||||
err = pgConn.flush()
|
||||
if err != nil {
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
|
|
@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecEmpty(t *testing.T) {
|
||||
|
@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
assert.Nil(t, result.CommandTag)
|
||||
assert.Equal(t, 0, len(result.Rows))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecMultipleQueries(t *testing.T) {
|
||||
|
@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||
|
@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
|
|||
} else {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecContextCanceled(t *testing.T) {
|
||||
|
@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) {
|
|||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParams(t *testing.T) {
|
||||
|
@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsCanceled(t *testing.T) {
|
||||
|
@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
|||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPrepared(t *testing.T) {
|
||||
|
@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPreparedCanceled(t *testing.T) {
|
||||
|
@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
|||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnBatchedQueries(t *testing.T) {
|
||||
|
@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) {
|
|||
|
||||
// Done
|
||||
require.False(t, pgConn.NextResult(context.Background()))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnRecoverFromTimeout(t *testing.T) {
|
||||
|
@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) {
|
|||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
}
|
||||
cancel()
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCancelQuery(t *testing.T) {
|
||||
|
@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) {
|
|||
} else {
|
||||
t.Errorf("expected pgconn.PgError got %v", err)
|
||||
}
|
||||
|
||||
require.False(t, pgConn.NextResult(context.Background()))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
|
@ -573,4 +599,6 @@ begin
|
|||
end$$;`)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "hello, world", msg)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue