Add single round-trip mode for ExecEx

batch-wip
Jack Christensen 2017-05-29 09:11:52 -05:00
parent 2e2c2ad778
commit dd5de3e49e
3 changed files with 143 additions and 43 deletions

144
conn.go
View File

@ -1428,82 +1428,82 @@ func (c *Conn) Ping(ctx context.Context) error {
return err
}
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.waitForPreviousCancelQuery(ctx)
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
err := c.waitForPreviousCancelQuery(ctx)
if err != nil {
return "", err
}
if err = c.lock(); err != nil {
return commandTag, err
if err := c.lock(); err != nil {
return "", err
}
defer c.unlock()
startTime := time.Now()
c.lastActivityTime = startTime
defer func() {
if err == nil {
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
}
} else {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
}
commandTag, err := c.execEx(ctx, sql, options, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
}
return commandTag, err
}
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
}
return commandTag, err
}
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
if options != nil && options.SimpleProtocol {
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
if err != nil {
return "", err
}
} else if options != nil && len(options.ParameterOids) > 0 {
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
if err != nil {
return "", err
}
// sync
buf = append(buf, 'S')
buf = pgio.AppendInt32(buf, 4)
n, err := c.conn.Write(buf)
if err != nil && fatalWriteErr(n, err) {
c.die(err)
return "", err
}
c.readyForQuery = false
} else {
if len(arguments) > 0 {
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.PrepareEx(ctx, "", sql, nil)
ps, err = c.prepareEx("", sql, nil)
if err != nil {
return "", err
}
}
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
err = c.sendPreparedQuery(ps, arguments...)
if err != nil {
return "", err
}
} else {
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
@ -1532,6 +1532,64 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
}
}
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
if len(arguments) != len(options.ParameterOids) {
return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids))
}
if len(options.ParameterOids) > 65535 {
return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids))
}
// parse
buf = append(buf, 'P')
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 0)
buf = append(buf, sql...)
buf = append(buf, 0)
buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids)))
for _, oid := range options.ParameterOids {
buf = pgio.AppendUint32(buf, uint32(oid))
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// bind
buf = append(buf, 'B')
sp = len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 0)
buf = append(buf, 0)
buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids)))
for i, oid := range options.ParameterOids {
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
}
buf = pgio.AppendInt16(buf, int16(len(arguments)))
for i, oid := range options.ParameterOids {
var err error
buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i])
if err != nil {
return nil, err
}
}
// No result values for an exec
buf = pgio.AppendInt16(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// execute
buf = append(buf, 'E')
buf = pgio.AppendInt32(buf, 9)
buf = append(buf, 0)
buf = pgio.AppendInt32(buf, 0)
return buf, nil
}
func (c *Conn) initContext(ctx context.Context) error {
if c.ctxInProgress {
return errors.New("ctx already in progress")

View File

@ -1155,6 +1155,47 @@ func TestExecExSimpleProtocol(t *testing.T) {
}
}
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
commandTag, err := conn.ExecEx(
context.Background(),
"insert into foo(name) values($1);",
&pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.VarcharOid}},
"bar'; drop table foo;--",
)
if err != nil {
t.Fatal(err)
}
if commandTag != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
}
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
_, err := conn.ExecEx(
context.Background(),
"insert into foo(name) values($1);",
&pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.Int4Oid}},
"bar'; drop table foo;--",
)
if err == nil {
t.Fatal("expected error but got none")
}
}
func TestPrepare(t *testing.T) {
t.Parallel()

View File

@ -348,6 +348,7 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
}
type QueryExOptions struct {
ParameterOids []pgtype.Oid
SimpleProtocol bool
}