From b3cc9aa8a7f02b4b18e1babb5f335b422fc41011 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 13:59:00 -0600 Subject: [PATCH] Add ensureReadyForQuery to pgconn --- pgconn/helper_test.go | 14 ++++++ pgconn/pgconn.go | 102 +++++++++++++++++++++++++++--------------- pgconn/pgconn_test.go | 28 ++++++++++++ 3 files changed, 109 insertions(+), 35 deletions(-) diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go index 8e7ca92f..1053310b 100644 --- a/pgconn/helper_test.go +++ b/pgconn/helper_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { defer cancel() require.Nil(t, conn.Close(ctx)) } + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + cancel() + + require.Nil(t, err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 6b6330dc..76836b9c 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -562,23 +562,28 @@ func (rr *PgResultReader) close() { // Flush sends the enqueued execs to the server. func (pgConn *PgConn) Flush(ctx context.Context) error { - defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() + err := pgConn.flush() + cleanup() + return preferContextOverNetTimeoutError(ctx, err) +} +// flush sends the enqueued execs to the server without handling a context. +func (pgConn *PgConn) flush() error { n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil { - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return preferContextOverNetTimeoutError(ctx, err) + if err != nil && n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true } - pgConn.pendingReadyForQueryCount += pgConn.batchCount - return nil + if err == nil { + pgConn.pendingReadyForQueryCount += pgConn.batchCount + } + + pgConn.resetBatch() + + return err } // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from @@ -646,13 +651,11 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } + err := pgConn.ensureReadyForQuery() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false } result, err := pgConn.Exec( @@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. +func (pgConn *PgConn) ensureReadyForQuery() error { + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + return err + } + } + + return nil +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExec(sql) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } return pgConn.bufferLastResult(ctx) @@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") - } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) + } pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 - err := pgConn.Flush(context.Background()) + err = pgConn.flush() if err != nil { return nil, preferContextOverNetTimeoutError(ctx, err) } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 98ec9664..e436d739 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { @@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) assert.Nil(t, result.CommandTag) assert.Equal(t, 0, len(result.Rows)) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { @@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { @@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) { } else { t.Errorf("unexpected error: %v", err) } + + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecParams(t *testing.T) { @@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecPrepared(t *testing.T) { @@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnBatchedQueries(t *testing.T) { @@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) { // Done require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) { assert.Equal(t, "1", string(result.Rows[0][0])) } cancel() + + ensureConnValid(t, pgConn) } func TestConnCancelQuery(t *testing.T) { @@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) { } else { t.Errorf("expected pgconn.PgError got %v", err) } + + require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -573,4 +599,6 @@ begin end$$;`) require.Nil(t, err) assert.Equal(t, "hello, world", msg) + + ensureConnValid(t, pgConn) }