mirror of https://github.com/jackc/pgx.git
Added ExecParams
parent
8d2e1463ed
commit
6ac70533bf
171
pgconn/pgconn.go
171
pgconn/pgconn.go
|
@ -20,6 +20,12 @@ import (
|
|||
|
||||
const batchBufferSize = 4096
|
||||
|
||||
// PostgreSQL extended protocol format codes
|
||||
const (
|
||||
TextFormatCode = 0
|
||||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
|
@ -379,6 +385,127 @@ func appendQuery(buf []byte, query string) []byte {
|
|||
return buf
|
||||
}
|
||||
|
||||
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
|
||||
func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte {
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, query...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(paramOIDs)))
|
||||
for _, oid := range paramOIDs {
|
||||
buf = pgio.AppendUint32(buf, oid)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
|
||||
func appendSync(buf []byte) []byte {
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it.
|
||||
func appendBind(
|
||||
buf []byte,
|
||||
destinationPortal,
|
||||
preparedStatement string,
|
||||
paramFormats []int16,
|
||||
paramValues [][]byte,
|
||||
resultFormatCodes []int16,
|
||||
) []byte {
|
||||
if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) {
|
||||
panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats)))
|
||||
}
|
||||
|
||||
buf = append(buf, 'B')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, destinationPortal...)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, preparedStatement...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(paramFormats)))
|
||||
for _, f := range paramFormats {
|
||||
buf = pgio.AppendInt16(buf, f)
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(paramValues)))
|
||||
for _, p := range paramValues {
|
||||
if p == nil {
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt32(buf, int32(len(p)))
|
||||
buf = append(buf, p...)
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes)))
|
||||
for _, fc := range resultFormatCodes {
|
||||
buf = pgio.AppendInt16(buf, fc)
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it.
|
||||
func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
|
||||
buf = append(buf, 'E')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
buf = append(buf, portal...)
|
||||
buf = append(buf, 0)
|
||||
buf = pgio.AppendUint32(buf, maxRows)
|
||||
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol.
|
||||
//
|
||||
// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc.
|
||||
//
|
||||
// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
|
||||
//
|
||||
// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for
|
||||
// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter.
|
||||
// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues).
|
||||
//
|
||||
// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
|
||||
// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if
|
||||
// len(paramFormats) is not 0, 1, or len(paramValues).
|
||||
//
|
||||
// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
|
||||
// binary format. If resultFormats is nil all results will be in text protocol.
|
||||
//
|
||||
// Query is only sent to the PostgreSQL server when Flush is called.
|
||||
func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||
if len(paramValues) > 65535 {
|
||||
panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues)))
|
||||
}
|
||||
if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) {
|
||||
panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs)))
|
||||
}
|
||||
|
||||
pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs)
|
||||
pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats)
|
||||
pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0)
|
||||
pgConn.batchBuf = appendSync(pgConn.batchBuf)
|
||||
pgConn.batchCount += 1
|
||||
}
|
||||
|
||||
type PgResultReader struct {
|
||||
pgConn *PgConn
|
||||
fieldDescriptions []pgproto3.FieldDescription
|
||||
|
@ -669,6 +796,50 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See
|
||||
// SendExecParams for parameter descriptions.
|
||||
//
|
||||
// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec).
|
||||
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) {
|
||||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
}
|
||||
|
||||
pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats)
|
||||
err := pgConn.Flush(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resultReader := pgConn.GetResult(ctx)
|
||||
if resultReader == nil {
|
||||
return nil, errors.New("unexpected missing result")
|
||||
}
|
||||
|
||||
var result *PgResult
|
||||
rows := [][][]byte{}
|
||||
for resultReader.NextRow() {
|
||||
row := make([][]byte, len(resultReader.Values()))
|
||||
copy(row, resultReader.Values())
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
commandTag, err := resultReader.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = &PgResult{
|
||||
Rows: rows,
|
||||
CommandTag: commandTag,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
return &PgError{
|
||||
Severity: msg.Severity,
|
||||
|
|
|
@ -285,7 +285,36 @@ func TestConnExecContextCanceled(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
|
||||
require.Nil(t, result)
|
||||
assert.Nil(t, result)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
}
|
||||
|
||||
func TestConnExecParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||
}
|
||||
|
||||
func TestConnExecParamsCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
|
||||
assert.Nil(t, result)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
|
|
Loading…
Reference in New Issue