Add QueryEx single round-trip mode

batch-wip
Jack Christensen 2017-05-29 11:27:44 -05:00
parent 85f30d10d2
commit dd5e6a77dc
2 changed files with 134 additions and 14 deletions

122
query.go
View File

@ -348,7 +348,12 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
}
type QueryExOptions struct {
ParameterOids []pgtype.Oid
// When ParameterOids are present and the query is not a prepared statement,
// then ParameterOids and ResultFormatCodes will be used to avoid an extra
// network round-trip.
ParameterOids []pgtype.Oid
ResultFormatCodes []int16
SimpleProtocol bool
}
@ -358,6 +363,10 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
return nil, err
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
c.lastActivityTime = time.Now()
rows = c.getRows(sql, args)
@ -368,13 +377,13 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
}
rows.unlockConn = true
if options != nil && options.SimpleProtocol {
err = c.initContext(ctx)
if err != nil {
rows.fatal(err)
return rows, err
}
err = c.initContext(ctx)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
if options != nil && options.SimpleProtocol {
err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
@ -384,10 +393,54 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
return rows, nil
}
if options != nil && len(options.ParameterOids) > 0 {
buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
if err != nil {
rows.fatal(err)
return rows, err
}
buf = appendSync(buf)
n, err := c.conn.Write(buf)
if err != nil && fatalWriteErr(n, err) {
rows.fatal(err)
c.die(err)
return nil, err
}
c.readyForQuery = false
fieldDescriptions, err := c.readUntilRowDescription()
if err != nil {
rows.fatal(err)
return nil, err
}
if len(options.ResultFormatCodes) == 0 {
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = TextFormatCode
}
} else if len(options.ResultFormatCodes) == 1 {
fc := options.ResultFormatCodes[0]
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = fc
}
} else {
for i := range options.ResultFormatCodes {
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
}
}
rows.sql = sql
rows.fields = fieldDescriptions
return rows, nil
}
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.PrepareEx(ctx, "", sql, nil)
ps, err = c.prepareEx("", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
@ -396,12 +449,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
rows.sql = ps.SQL
rows.fields = ps.FieldDescriptions
err = c.initContext(ctx)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
err = c.sendPreparedQuery(ps, args...)
if err != nil {
rows.fatal(err)
@ -410,6 +457,53 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
return rows, rows.err
}
func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
if len(arguments) != len(options.ParameterOids) {
return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids))
}
if len(options.ParameterOids) > 65535 {
return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids))
}
buf = appendParse(buf, "", sql, options.ParameterOids)
buf = appendDescribe(buf, 'S', "")
buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, options.ResultFormatCodes)
if err != nil {
return nil, err
}
buf = appendExecute(buf, "", 0)
return buf, nil
}
func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) {
for {
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.ParameterDescription:
case *pgproto3.RowDescription:
fieldDescriptions := c.rxRowDescription(msg)
for i := range fieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOid(fieldDescriptions[i].DataType); ok {
fieldDescriptions[i].DataTypeName = dt.Name
} else {
return nil, fmt.Errorf("unknown oid: %d", fieldDescriptions[i].DataType)
}
}
return fieldDescriptions, nil
default:
if err := c.processContextFreeMsg(msg); err != nil {
return nil, err
}
}
}
}
func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
if c.RuntimeParams["standard_conforming_strings"] != "on" {
return errors.New("simple protocol queries must be run with standard_conforming_strings=on")

View File

@ -1182,6 +1182,32 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnQueryRowExSingleRoundTrip(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var result int32
err := conn.QueryRowEx(
context.Background(),
"select $1 + $2",
&pgx.QueryExOptions{
ParameterOids: []pgtype.Oid{pgtype.Int4Oid, pgtype.Int4Oid},
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
},
1, 2,
).Scan(&result)
if err != nil {
t.Fatal(err)
}
if result != 3 {
t.Fatal("result => %d, want %d", result, 3)
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocol(t *testing.T) {
t.Parallel()