diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 87ba0096..db9c758d 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -20,6 +20,12 @@ import ( const batchBufferSize = 4096 +// PostgreSQL extended protocol format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -379,6 +385,127 @@ func appendQuery(buf []byte, query string) []byte { return buf } +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramOIDs))) + for _, oid := range paramOIDs { + buf = pgio.AppendUint32(buf, oid) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf +} + +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + paramFormats []int16, + paramValues [][]byte, + resultFormatCodes []int16, +) []byte { + if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { + panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) + } + + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramFormats))) + for _, f := range paramFormats { + buf = pgio.AppendInt16(buf, f) + } + + buf = pgio.AppendInt16(buf, int16(len(paramValues))) + for _, p := range paramValues { + if p == nil { + buf = pgio.AppendInt32(buf, -1) + continue + } + + buf = pgio.AppendInt32(buf, int32(len(p))) + buf = append(buf, p...) + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. +func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if len(paramValues) > 65535 { + panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) + } + if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { + panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) + } + + pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -669,6 +796,50 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return result, nil } +// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See +// SendExecParams for parameter descriptions. +// +// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 05318dac..fa1ec5fc 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -285,7 +285,36 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - require.Nil(t, result) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + assert.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background()))