From dd5de3e49e56174a232d7156c745fd3ea3a70d7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 09:11:52 -0500 Subject: [PATCH] Add single round-trip mode for ExecEx --- conn.go | 144 ++++++++++++++++++++++++++++++++++++--------------- conn_test.go | 41 +++++++++++++++ query.go | 1 + 3 files changed, 143 insertions(+), 43 deletions(-) diff --git a/conn.go b/conn.go index c4c054dd..6b0cc0c5 100644 --- a/conn.go +++ b/conn.go @@ -1428,82 +1428,82 @@ func (c *Conn) Ping(ctx context.Context) error { return err } -func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { - err = c.waitForPreviousCancelQuery(ctx) +func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { + err := c.waitForPreviousCancelQuery(ctx) if err != nil { return "", err } - if err = c.lock(); err != nil { - return commandTag, err + if err := c.lock(); err != nil { + return "", err } + defer c.unlock() startTime := time.Now() c.lastActivityTime = startTime - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) - } + commandTag, err := c.execEx(ctx, sql, options, arguments...) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } + return commandTag, err + } - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) }() if options != nil && options.SimpleProtocol { - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - err = c.sanitizeAndSendSimpleQuery(sql, arguments...) if err != nil { return "", err - } + } else if options != nil && len(options.ParameterOids) > 0 { + buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) + if err != nil { + return "", err + } + + // sync + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + c.die(err) + return "", err + } + c.readyForQuery = false } else { if len(arguments) > 0 { ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareEx(ctx, "", sql, nil) + ps, err = c.prepareEx("", sql, nil) if err != nil { return "", err } } - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - err = c.sendPreparedQuery(ps, arguments...) if err != nil { return "", err } } else { - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - if err = c.sendQuery(sql, arguments...); err != nil { return } @@ -1532,6 +1532,64 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, } } +func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOids) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + } + + if len(options.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + } + + // parse + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, sql...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) + for _, oid := range options.ParameterOids { + buf = pgio.AppendUint32(buf, uint32(oid)) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + // bind + buf = append(buf, 'B') + sp = len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) + for i, oid := range options.ParameterOids { + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + } + + buf = pgio.AppendInt16(buf, int16(len(arguments))) + for i, oid := range options.ParameterOids { + var err error + buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) + if err != nil { + return nil, err + } + } + + // No result values for an exec + buf = pgio.AppendInt16(buf, 0) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + // execute + buf = append(buf, 'E') + buf = pgio.AppendInt32(buf, 9) + buf = append(buf, 0) + buf = pgio.AppendInt32(buf, 0) + + return buf, nil +} + func (c *Conn) initContext(ctx context.Context) error { if c.ctxInProgress { return errors.New("ctx already in progress") diff --git a/conn_test.go b/conn_test.go index acee1b49..4d001da5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1155,6 +1155,47 @@ func TestExecExSimpleProtocol(t *testing.T) { } } +func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + commandTag, err := conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.VarcharOid}}, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + +func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + _, err := conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.Int4Oid}}, + "bar'; drop table foo;--", + ) + if err == nil { + t.Fatal("expected error but got none") + } +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 10eda1bc..0962b352 100644 --- a/query.go +++ b/query.go @@ -348,6 +348,7 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } type QueryExOptions struct { + ParameterOids []pgtype.Oid SimpleProtocol bool }