mirror of https://github.com/jackc/pgx.git
Add single round-trip mode for ExecEx
parent
2e2c2ad778
commit
dd5de3e49e
144
conn.go
144
conn.go
|
@ -1428,82 +1428,82 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
|
||||
err := c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err = c.lock(); err != nil {
|
||||
return commandTag, err
|
||||
if err := c.lock(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer c.unlock()
|
||||
|
||||
startTime := time.Now()
|
||||
c.lastActivityTime = startTime
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
}
|
||||
} else {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
}
|
||||
commandTag, err := c.execEx(ctx, sql, options, arguments...)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
}
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if options != nil && options.SimpleProtocol {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
}
|
||||
} else if options != nil && len(options.ParameterOids) > 0 {
|
||||
buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// sync
|
||||
buf = append(buf, 'S')
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
n, err := c.conn.Write(buf)
|
||||
if err != nil && fatalWriteErr(n, err) {
|
||||
c.die(err)
|
||||
return "", err
|
||||
}
|
||||
c.readyForQuery = false
|
||||
} else {
|
||||
if len(arguments) > 0 {
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.PrepareEx(ctx, "", sql, nil)
|
||||
ps, err = c.prepareEx("", sql, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
err = c.sendPreparedQuery(ps, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -1532,6 +1532,64 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
|
||||
if len(arguments) != len(options.ParameterOids) {
|
||||
return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids))
|
||||
}
|
||||
|
||||
if len(options.ParameterOids) > 65535 {
|
||||
return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids))
|
||||
}
|
||||
|
||||
// parse
|
||||
buf = append(buf, 'P')
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, sql...)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids)))
|
||||
for _, oid := range options.ParameterOids {
|
||||
buf = pgio.AppendUint32(buf, uint32(oid))
|
||||
}
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// bind
|
||||
buf = append(buf, 'B')
|
||||
sp = len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, 0)
|
||||
buf = append(buf, 0)
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids)))
|
||||
for i, oid := range options.ParameterOids {
|
||||
buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(arguments)))
|
||||
for i, oid := range options.ParameterOids {
|
||||
var err error
|
||||
buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// No result values for an exec
|
||||
buf = pgio.AppendInt16(buf, 0)
|
||||
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
// execute
|
||||
buf = append(buf, 'E')
|
||||
buf = pgio.AppendInt32(buf, 9)
|
||||
buf = append(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *Conn) initContext(ctx context.Context) error {
|
||||
if c.ctxInProgress {
|
||||
return errors.New("ctx already in progress")
|
||||
|
|
41
conn_test.go
41
conn_test.go
|
@ -1155,6 +1155,47 @@ func TestExecExSimpleProtocol(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
commandTag, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.VarcharOid}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "create temporary table foo(name varchar primary key);")
|
||||
|
||||
_, err := conn.ExecEx(
|
||||
context.Background(),
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.Int4Oid}},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue