mirror of https://github.com/jackc/pgx.git
Add ExecPrepared
parent
421cfd5547
commit
b537f2c412
126
pgconn/pgconn.go
126
pgconn/pgconn.go
|
@ -387,6 +387,10 @@ func appendQuery(buf []byte, query string) []byte {
|
|||
|
||||
// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it.
|
||||
func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte {
|
||||
if len(paramOIDs) > 65535 {
|
||||
panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs)))
|
||||
}
|
||||
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
@ -404,6 +408,19 @@ func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []by
|
|||
return buf
|
||||
}
|
||||
|
||||
// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it.
|
||||
func appendDescribe(buf []byte, objectType byte, name string) []byte {
|
||||
buf = append(buf, 'D')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, objectType)
|
||||
buf = append(buf, name...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it.
|
||||
func appendSync(buf []byte) []byte {
|
||||
buf = append(buf, 'S')
|
||||
|
@ -424,6 +441,9 @@ func appendBind(
|
|||
if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) {
|
||||
panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats)))
|
||||
}
|
||||
if len(paramValues) > 65535 {
|
||||
panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues)))
|
||||
}
|
||||
|
||||
buf = append(buf, 'B')
|
||||
sp := len(buf)
|
||||
|
@ -492,9 +512,6 @@ func appendExecute(buf []byte, portal string, maxRows uint32) []byte {
|
|||
//
|
||||
// Query is only sent to the PostgreSQL server when Flush is called.
|
||||
func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||
if len(paramValues) > 65535 {
|
||||
panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues)))
|
||||
}
|
||||
if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) {
|
||||
panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs)))
|
||||
}
|
||||
|
@ -506,6 +523,25 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs
|
|||
pgConn.batchCount += 1
|
||||
}
|
||||
|
||||
// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol.
|
||||
//
|
||||
// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
|
||||
//
|
||||
// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
|
||||
// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if
|
||||
// len(paramFormats) is not 0, 1, or len(paramValues).
|
||||
//
|
||||
// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
|
||||
// binary format. If resultFormats is nil all results will be in text protocol.
|
||||
//
|
||||
// Query is only sent to the PostgreSQL server when Flush is called.
|
||||
func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
||||
pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats)
|
||||
pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0)
|
||||
pgConn.batchBuf = appendSync(pgConn.batchBuf)
|
||||
pgConn.batchCount += 1
|
||||
}
|
||||
|
||||
type PgResultReader struct {
|
||||
pgConn *PgConn
|
||||
fieldDescriptions []pgproto3.FieldDescription
|
||||
|
@ -840,6 +876,90 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and
|
||||
// returns it. See SendExecPrepared for parameter descriptions.
|
||||
//
|
||||
// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec).
|
||||
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) {
|
||||
if pgConn.batchCount != 0 {
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return nil, errors.New("unread previous results")
|
||||
}
|
||||
|
||||
pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats)
|
||||
err := pgConn.Flush(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resultReader := pgConn.GetResult(ctx)
|
||||
if resultReader == nil {
|
||||
return nil, errors.New("unexpected missing result")
|
||||
}
|
||||
|
||||
var result *PgResult
|
||||
rows := [][][]byte{}
|
||||
for resultReader.NextRow() {
|
||||
row := make([][]byte, len(resultReader.Values()))
|
||||
copy(row, resultReader.Values())
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
commandTag, err := resultReader.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = &PgResult{
|
||||
Rows: rows,
|
||||
CommandTag: commandTag,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error {
|
||||
if pgConn.batchCount != 0 {
|
||||
return errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return errors.New("unread previous results")
|
||||
}
|
||||
|
||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContext()
|
||||
|
||||
pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs)
|
||||
pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name)
|
||||
pgConn.batchBuf = appendSync(pgConn.batchBuf)
|
||||
pgConn.batchCount += 1
|
||||
err := pgConn.Flush(context.Background())
|
||||
if err != nil {
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
for pgConn.pendingReadyForQueryCount > 0 {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ParameterDescription:
|
||||
// TODO
|
||||
case *pgproto3.RowDescription:
|
||||
// TODO
|
||||
case *pgproto3.ErrorResponse:
|
||||
return errorResponseToPgError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
return &PgError{
|
||||
Severity: msg.Severity,
|
||||
|
|
|
@ -320,6 +320,41 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
|||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
}
|
||||
|
||||
func TestConnExecPrepared(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(result.Rows))
|
||||
assert.Equal(t, "Hello, world", string(result.Rows[0][0]))
|
||||
}
|
||||
|
||||
func TestConnExecPreparedCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
||||
assert.Nil(t, result)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
|
||||
}
|
||||
|
||||
func TestConnBatchedQueries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -327,8 +362,12 @@ func TestConnBatchedQueries(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
pgConn.SendExec("select 'SendExec 1'")
|
||||
pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil)
|
||||
pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil)
|
||||
pgConn.SendExec("select 'SendExec 2'")
|
||||
pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil)
|
||||
err = pgConn.Flush(context.Background())
|
||||
|
@ -369,6 +408,24 @@ func TestConnBatchedQueries(t *testing.T) {
|
|||
assert.Equal(t, "SELECT 1", string(commandTag))
|
||||
assert.Nil(t, err)
|
||||
|
||||
// "SendExecPrepared 1"
|
||||
resultReader = pgConn.GetResult(context.Background())
|
||||
require.NotNil(t, resultReader)
|
||||
|
||||
rows = [][][]byte{}
|
||||
for resultReader.NextRow() {
|
||||
row := make([][]byte, len(resultReader.Values()))
|
||||
copy(row, resultReader.Values())
|
||||
rows = append(rows, row)
|
||||
}
|
||||
require.Len(t, rows, 1)
|
||||
require.Len(t, rows[0], 1)
|
||||
assert.Equal(t, "SendExecPrepared 1", string(rows[0][0]))
|
||||
|
||||
commandTag, err = resultReader.Close()
|
||||
assert.Equal(t, "SELECT 1", string(commandTag))
|
||||
assert.Nil(t, err)
|
||||
|
||||
// "SendExec 2"
|
||||
resultReader = pgConn.GetResult(context.Background())
|
||||
require.NotNil(t, resultReader)
|
||||
|
|
Loading…
Reference in New Issue