Add ensureReadyForQuery to pgconn

pull/483/head
Jack Christensen 2019-01-02 13:59:00 -06:00
parent 19a8df16b6
commit b3cc9aa8a7
3 changed files with 109 additions and 35 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) {
defer cancel() defer cancel()
require.Nil(t, conn.Close(ctx)) require.Nil(t, conn.Close(ctx))
} }
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil)
cancel()
require.Nil(t, err)
assert.Equal(t, 3, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0]))
assert.Equal(t, "2", string(result.Rows[1][0]))
assert.Equal(t, "3", string(result.Rows[2][0]))
}

View File

@ -562,23 +562,28 @@ func (rr *PgResultReader) close() {
// Flush sends the enqueued execs to the server. // Flush sends the enqueued execs to the server.
func (pgConn *PgConn) Flush(ctx context.Context) error { func (pgConn *PgConn) Flush(ctx context.Context) error {
defer pgConn.resetBatch()
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup() err := pgConn.flush()
cleanup()
return preferContextOverNetTimeoutError(ctx, err)
}
// flush sends the enqueued execs to the server without handling a context.
func (pgConn *PgConn) flush() error {
n, err := pgConn.conn.Write(pgConn.batchBuf) n, err := pgConn.conn.Write(pgConn.batchBuf)
if err != nil { if err != nil && n > 0 {
if n > 0 { // Close connection because cannot recover from partially sent message.
// Close connection because cannot recover from partially sent message. pgConn.conn.Close()
pgConn.conn.Close() pgConn.closed = true
pgConn.closed = true
}
return preferContextOverNetTimeoutError(ctx, err)
} }
pgConn.pendingReadyForQueryCount += pgConn.batchCount if err == nil {
return nil pgConn.pendingReadyForQueryCount += pgConn.batchCount
}
pgConn.resetBatch()
return err
} }
// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from
@ -646,13 +651,11 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContext() defer cleanupContext()
for pgConn.pendingReadyForQueryCount > 0 { err := pgConn.ensureReadyForQuery()
_, err := pgConn.ReceiveMessage() if err != nil {
if err != nil { preferContextOverNetTimeoutError(ctx, err)
preferContextOverNetTimeoutError(ctx, err) pgConn.Close(context.Background())
pgConn.Close(context.Background()) return false
return false
}
} }
result, err := pgConn.Exec( result, err := pgConn.Exec(
@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
return true return true
} }
// ensureReadyForQuery reads until pendingReadyForQueryCount == 0.
func (pgConn *PgConn) ensureReadyForQuery() error {
for pgConn.pendingReadyForQueryCount > 0 {
_, err := pgConn.ReceiveMessage()
if err != nil {
return err
}
}
return nil
}
func (pgConn *PgConn) resetBatch() { func (pgConn *PgConn) resetBatch() {
pgConn.batchCount = 0 pgConn.batchCount = 0
if len(pgConn.batchBuf) > batchBufferSize { if len(pgConn.batchBuf) > batchBufferSize {
@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
if pgConn.batchCount != 0 { if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends") return nil, errors.New("unflushed previous sends")
} }
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results") cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
} }
pgConn.SendExec(sql) pgConn.SendExec(sql)
err := pgConn.Flush(ctx) err = pgConn.flush()
if err != nil { if err != nil {
return nil, err return nil, preferContextOverNetTimeoutError(ctx, err)
} }
return pgConn.bufferLastResult(ctx) return pgConn.bufferLastResult(ctx)
@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
if pgConn.batchCount != 0 { if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends") return nil, errors.New("unflushed previous sends")
} }
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results") cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
} }
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
err := pgConn.Flush(ctx) err = pgConn.flush()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
if pgConn.batchCount != 0 { if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends") return nil, errors.New("unflushed previous sends")
} }
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results") cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanup()
err := pgConn.ensureReadyForQuery()
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
} }
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
err := pgConn.Flush(ctx) err = pgConn.flush()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
if pgConn.batchCount != 0 { if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends") return nil, errors.New("unflushed previous sends")
} }
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results")
}
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContext() defer cleanup()
err := pgConn.ensureReadyForQuery()
if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err)
}
pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf)
pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf)
pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf)
pgConn.batchCount += 1 pgConn.batchCount += 1
err := pgConn.Flush(context.Background()) err = pgConn.flush()
if err != nil { if err != nil {
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }

View File

@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
ensureConnValid(t, pgConn)
} }
func TestConnExecEmpty(t *testing.T) { func TestConnExecEmpty(t *testing.T) {
@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Nil(t, result.CommandTag) assert.Nil(t, result.CommandTag)
assert.Equal(t, 0, len(result.Rows)) assert.Equal(t, 0, len(result.Rows))
ensureConnValid(t, pgConn)
} }
func TestConnExecMultipleQueries(t *testing.T) { func TestConnExecMultipleQueries(t *testing.T) {
@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "1", string(result.Rows[0][0]))
ensureConnValid(t, pgConn)
} }
func TestConnExecMultipleQueriesError(t *testing.T) { func TestConnExecMultipleQueriesError(t *testing.T) {
@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
} else { } else {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
ensureConnValid(t, pgConn)
} }
func TestConnExecContextCanceled(t *testing.T) { func TestConnExecContextCanceled(t *testing.T) {
@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) {
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background())) assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
ensureConnValid(t, pgConn)
} }
func TestConnExecParams(t *testing.T) { func TestConnExecParams(t *testing.T) {
@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "Hello, world", string(result.Rows[0][0])) assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
ensureConnValid(t, pgConn)
} }
func TestConnExecParamsCanceled(t *testing.T) { func TestConnExecParamsCanceled(t *testing.T) {
@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background())) assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
ensureConnValid(t, pgConn)
} }
func TestConnExecPrepared(t *testing.T) { func TestConnExecPrepared(t *testing.T) {
@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "Hello, world", string(result.Rows[0][0])) assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
ensureConnValid(t, pgConn)
} }
func TestConnExecPreparedCanceled(t *testing.T) { func TestConnExecPreparedCanceled(t *testing.T) {
@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background())) assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
ensureConnValid(t, pgConn)
} }
func TestConnBatchedQueries(t *testing.T) { func TestConnBatchedQueries(t *testing.T) {
@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) {
// Done // Done
require.False(t, pgConn.NextResult(context.Background())) require.False(t, pgConn.NextResult(context.Background()))
ensureConnValid(t, pgConn)
} }
func TestConnRecoverFromTimeout(t *testing.T) { func TestConnRecoverFromTimeout(t *testing.T) {
@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) {
assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "1", string(result.Rows[0][0]))
} }
cancel() cancel()
ensureConnValid(t, pgConn)
} }
func TestConnCancelQuery(t *testing.T) { func TestConnCancelQuery(t *testing.T) {
@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) {
} else { } else {
t.Errorf("expected pgconn.PgError got %v", err) t.Errorf("expected pgconn.PgError got %v", err)
} }
require.False(t, pgConn.NextResult(context.Background()))
ensureConnValid(t, pgConn)
} }
func TestCommandTag(t *testing.T) { func TestCommandTag(t *testing.T) {
@ -573,4 +599,6 @@ begin
end$$;`) end$$;`)
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, "hello, world", msg) assert.Equal(t, "hello, world", msg)
ensureConnValid(t, pgConn)
} }