From 46966227bc5245d6ede511781e1adf71fcc6d926 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Mar 2022 10:04:02 -0600 Subject: [PATCH] Enable all QueryExecModes for exec path --- conn.go | 84 ++++++++++++++++++++++++++++++++++++++++------------ conn_test.go | 10 ++++++- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index c85bca88..54d128ab 100644 --- a/conn.go +++ b/conn.go @@ -405,48 +405,62 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol + mode := c.config.DefaultQueryExecMode optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { case QueryExecMode: - simpleProtocol = arg == QueryExecModeSimpleProtocol + mode = arg arguments = arguments[1:] default: break optionLoop } } + // Always use simple protocol when there are no arguments. + if len(arguments) == 0 { + mode = QueryExecModeSimpleProtocol + } + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } - if simpleProtocol { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if len(arguments) == 0 { - return c.execSimpleProtocol(ctx, sql, arguments) - } - - if c.statementCache != nil { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + } sd, err := c.statementCache.Get(ctx, sql) if err != nil { return pgconn.CommandTag{}, err } - if c.statementCache.Mode() == stmtcache.ModeDescribe { - return c.execParams(ctx, sd, arguments) + return c.execPrepared(ctx, sd, arguments) + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + } + sd, err := c.descriptionCache.Get(ctx, sql) + if err != nil { + return pgconn.CommandTag{}, err + } + + return c.execParams(ctx, sd, arguments) + case QueryExecModeDescribeExec: + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err } return c.execPrepared(ctx, sd, arguments) + case QueryExecModeExec: + return c.execSQLParams(ctx, sql, arguments) + case QueryExecModeSimpleProtocol: + return c.execSimpleProtocol(ctx, sql, arguments) + default: + return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode) } - - sd, err := c.Prepare(ctx, "", sql) - if err != nil { - return pgconn.CommandTag{}, err - } - return c.execPrepared(ctx, sd, arguments) } func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { @@ -510,6 +524,38 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return result.CommandTag, result.Err } +type unknownArgumentTypeQueryExecModeExecError struct { + arg interface{} +} + +func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { + return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) +} + +func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) { + c.eqb.Reset() + + anynil.NormalizeSlice(args) + + paramOIDs := make([]uint32, len(args)) + + for i := range args { + dt, ok := c.TypeMap().TypeForValue(args[i]) + if !ok { + return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} + } + err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i]) + if err != nil { + return pgconn.CommandTag{}, err + } + paramOIDs[i] = dt.OID + } + + result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} + func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { r := &connRows{} diff --git a/conn_test.go b/conn_test.go index 625d9693..85b0da2b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -256,7 +256,15 @@ func TestExecFailureWithArguments(t *testing.T) { assert.False(t, pgconn.SafeToRetry(err)) _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") - require.Error(t, err) + if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec { + // The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it + // locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check + // for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing + // the SQL. + require.NoError(t, err) + } else { + require.Error(t, err) + } }) }