Add ExecPrepared

pull/483/head
Jack Christensen 2019-01-01 11:32:56 -06:00
parent 421cfd5547
commit b537f2c412
2 changed files with 180 additions and 3 deletions

View File

@ -387,6 +387,10 @@ func appendQuery(buf []byte, query string) []byte {
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte {
if len(paramOIDs) > 65535 {
panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs)))
}
buf = append(buf, 'P')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
@ -404,6 +408,19 @@ func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []by
return buf
}
// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it.
func appendDescribe(buf []byte, objectType byte, name string) []byte {
buf = append(buf, 'D')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, objectType)
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
func appendSync(buf []byte) []byte {
buf = append(buf, 'S')
@ -424,6 +441,9 @@ func appendBind(
if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) {
panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats)))
}
if len(paramValues) > 65535 {
panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues)))
}
buf = append(buf, 'B')
sp := len(buf)
@ -492,9 +512,6 @@ func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
//
// Query is only sent to the PostgreSQL server when Flush is called.
func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
if len(paramValues) > 65535 {
panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues)))
}
if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) {
panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs)))
}
@ -506,6 +523,25 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs
pgConn.batchCount += 1
}
// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol.
//
// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
//
// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if
// len(paramFormats) is not 0, 1, or len(paramValues).
//
// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
// binary format. If resultFormats is nil all results will be in text protocol.
//
// Query is only sent to the PostgreSQL server when Flush is called.
func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats)
pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0)
pgConn.batchBuf = appendSync(pgConn.batchBuf)
pgConn.batchCount += 1
}
type PgResultReader struct {
pgConn *PgConn
fieldDescriptions []pgproto3.FieldDescription
@ -840,6 +876,90 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
return result, nil
}
// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and
// returns it. See SendExecPrepared for parameter descriptions.
//
// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec).
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) {
if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends")
}
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results")
}
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
err := pgConn.Flush(ctx)
if err != nil {
return nil, err
}
resultReader := pgConn.GetResult(ctx)
if resultReader == nil {
return nil, errors.New("unexpected missing result")
}
var result *PgResult
rows := [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))
copy(row, resultReader.Values())
rows = append(rows, row)
}
commandTag, err := resultReader.Close()
if err != nil {
return nil, err
}
result = &PgResult{
Rows: rows,
CommandTag: commandTag,
}
return result, nil
}
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error {
if pgConn.batchCount != 0 {
return errors.New("unflushed previous sends")
}
if pgConn.pendingReadyForQueryCount != 0 {
return errors.New("unread previous results")
}
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContext()
pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs)
pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name)
pgConn.batchBuf = appendSync(pgConn.batchBuf)
pgConn.batchCount += 1
err := pgConn.Flush(context.Background())
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
}
for pgConn.pendingReadyForQueryCount > 0 {
msg, err := pgConn.ReceiveMessage()
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ParameterDescription:
// TODO
case *pgproto3.RowDescription:
// TODO
case *pgproto3.ErrorResponse:
return errorResponseToPgError(msg)
}
}
return nil
}
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
Severity: msg.Severity,

View File

@ -320,6 +320,41 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
}
func TestConnExecPrepared(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
require.Nil(t, err)
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
}
func TestConnExecPreparedCanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
require.Nil(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
assert.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
}
func TestConnBatchedQueries(t *testing.T) {
t.Parallel()
@ -327,8 +362,12 @@ func TestConnBatchedQueries(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
require.Nil(t, err)
pgConn.SendExec("select 'SendExec 1'")
pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil)
pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil)
pgConn.SendExec("select 'SendExec 2'")
pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil)
err = pgConn.Flush(context.Background())
@ -369,6 +408,24 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Equal(t, "SELECT 1", string(commandTag))
assert.Nil(t, err)
// "SendExecPrepared 1"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
rows = [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))
copy(row, resultReader.Values())
rows = append(rows, row)
}
require.Len(t, rows, 1)
require.Len(t, rows[0], 1)
assert.Equal(t, "SendExecPrepared 1", string(rows[0][0]))
commandTag, err = resultReader.Close()
assert.Equal(t, "SELECT 1", string(commandTag))
assert.Nil(t, err)
// "SendExec 2"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)