From b537f2c4126c124bc21a2fdb6369c2ffde4087bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 11:32:56 -0600 Subject: [PATCH] Add ExecPrepared --- pgconn/pgconn.go | 126 +++++++++++++++++++++++++++++++++++++++++- pgconn/pgconn_test.go | 57 +++++++++++++++++++ 2 files changed, 180 insertions(+), 3 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index db9c758d..f2e46539 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -387,6 +387,10 @@ func appendQuery(buf []byte, query string) []byte { // appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + if len(paramOIDs) > 65535 { + panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs))) + } + buf = append(buf, 'P') sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -404,6 +408,19 @@ func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []by return buf } +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + // appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. func appendSync(buf []byte) []byte { buf = append(buf, 'S') @@ -424,6 +441,9 @@ func appendBind( if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) } + if len(paramValues) > 65535 { + panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues))) + } buf = append(buf, 'B') sp := len(buf) @@ -492,9 +512,6 @@ func appendExecute(buf []byte, portal string, maxRows uint32) []byte { // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramValues) > 65535 { - panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) - } if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) } @@ -506,6 +523,25 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs pgConn.batchCount += 1 } +// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -840,6 +876,90 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result, nil } +// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and +// returns it. See SendExecPrepared for parameter descriptions. +// +// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + +// Prepare creates a prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { + if pgConn.batchCount != 0 { + return errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return errors.New("unread previous results") + } + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContext() + + pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs) + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 + err := pgConn.Flush(context.Background()) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + // TODO + case *pgproto3.RowDescription: + // TODO + case *pgproto3.ErrorResponse: + return errorResponseToPgError(msg) + } + } + + return nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index a765dc4c..35f5b536 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -320,6 +320,41 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + + result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.Nil(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + func TestConnBatchedQueries(t *testing.T) { t.Parallel() @@ -327,8 +362,12 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + pgConn.SendExec("select 'SendExec 1'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) + pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) pgConn.SendExec("select 'SendExec 2'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) err = pgConn.Flush(context.Background()) @@ -369,6 +408,24 @@ func TestConnBatchedQueries(t *testing.T) { assert.Equal(t, "SELECT 1", string(commandTag)) assert.Nil(t, err) + // "SendExecPrepared 1" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + // "SendExec 2" resultReader = pgConn.GetResult(context.Background()) require.NotNil(t, resultReader)