Added ExecParams

pull/483/head
Jack Christensen 2018-12-31 19:59:32 -06:00
parent 8d2e1463ed
commit 6ac70533bf
2 changed files with 201 additions and 1 deletions

View File

@ -20,6 +20,12 @@ import (
const batchBufferSize = 4096
// PostgreSQL extended protocol format codes
const (
TextFormatCode = 0
BinaryFormatCode = 1
)
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
// PgError represents an error reported by the PostgreSQL server. See
@ -379,6 +385,127 @@ func appendQuery(buf []byte, query string) []byte {
return buf
}
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte {
buf = append(buf, 'P')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, name...)
buf = append(buf, 0)
buf = append(buf, query...)
buf = append(buf, 0)
buf = pgio.AppendInt16(buf, int16(len(paramOIDs)))
for _, oid := range paramOIDs {
buf = pgio.AppendUint32(buf, oid)
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
func appendSync(buf []byte) []byte {
buf = append(buf, 'S')
buf = pgio.AppendInt32(buf, 4)
return buf
}
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
func appendBind(
buf []byte,
destinationPortal,
preparedStatement string,
paramFormats []int16,
paramValues [][]byte,
resultFormatCodes []int16,
) []byte {
if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) {
panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats)))
}
buf = append(buf, 'B')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, destinationPortal...)
buf = append(buf, 0)
buf = append(buf, preparedStatement...)
buf = append(buf, 0)
buf = pgio.AppendInt16(buf, int16(len(paramFormats)))
for _, f := range paramFormats {
buf = pgio.AppendInt16(buf, f)
}
buf = pgio.AppendInt16(buf, int16(len(paramValues)))
for _, p := range paramValues {
if p == nil {
buf = pgio.AppendInt32(buf, -1)
continue
}
buf = pgio.AppendInt32(buf, int32(len(p)))
buf = append(buf, p...)
}
buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
for _, fc := range resultFormatCodes {
buf = pgio.AppendInt16(buf, fc)
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
buf = append(buf, 'E')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, portal...)
buf = append(buf, 0)
buf = pgio.AppendUint32(buf, maxRows)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return buf
}
// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol.
//
// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc.
//
// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
//
// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for
// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter.
// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues).
//
// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if
// len(paramFormats) is not 0, 1, or len(paramValues).
//
// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
// binary format. If resultFormats is nil all results will be in text protocol.
//
// Query is only sent to the PostgreSQL server when Flush is called.
func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
if len(paramValues) > 65535 {
panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues)))
}
if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) {
panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs)))
}
pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs)
pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats)
pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0)
pgConn.batchBuf = appendSync(pgConn.batchBuf)
pgConn.batchCount += 1
}
type PgResultReader struct {
pgConn *PgConn
fieldDescriptions []pgproto3.FieldDescription
@ -669,6 +796,50 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
return result, nil
}
// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See
// SendExecParams for parameter descriptions.
//
// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec).
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) {
if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends")
}
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results")
}
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
err := pgConn.Flush(ctx)
if err != nil {
return nil, err
}
resultReader := pgConn.GetResult(ctx)
if resultReader == nil {
return nil, errors.New("unexpected missing result")
}
var result *PgResult
rows := [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))
copy(row, resultReader.Values())
rows = append(rows, row)
}
commandTag, err := resultReader.Close()
if err != nil {
return nil, err
}
result = &PgResult{
Rows: rows,
CommandTag: commandTag,
}
return result, nil
}
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
Severity: msg.Severity,

View File

@ -285,7 +285,36 @@ func TestConnExecContextCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
require.Nil(t, result)
assert.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
}
func TestConnExecParams(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
}
func TestConnExecParamsCanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
assert.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))