Enable all QueryExecModes for exec path

pull/1170/head
Jack Christensen 2022-03-12 10:04:02 -06:00
parent 8e341e20f3
commit 46966227bc
2 changed files with 74 additions and 20 deletions

84
conn.go
View File

@ -405,48 +405,62 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
}
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
mode := c.config.DefaultQueryExecMode
optionLoop:
for len(arguments) > 0 {
switch arg := arguments[0].(type) {
case QueryExecMode:
simpleProtocol = arg == QueryExecModeSimpleProtocol
mode = arg
arguments = arguments[1:]
default:
break optionLoop
}
}
// Always use simple protocol when there are no arguments.
if len(arguments) == 0 {
mode = QueryExecModeSimpleProtocol
}
if sd, ok := c.preparedStatements[sql]; ok {
return c.execPrepared(ctx, sd, arguments)
}
if simpleProtocol {
return c.execSimpleProtocol(ctx, sql, arguments)
}
if len(arguments) == 0 {
return c.execSimpleProtocol(ctx, sql, arguments)
}
if c.statementCache != nil {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
}
sd, err := c.statementCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
}
if c.statementCache.Mode() == stmtcache.ModeDescribe {
return c.execParams(ctx, sd, arguments)
return c.execPrepared(ctx, sd, arguments)
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
}
sd, err := c.descriptionCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
}
return c.execParams(ctx, sd, arguments)
case QueryExecModeDescribeExec:
sd, err := c.Prepare(ctx, "", sql)
if err != nil {
return pgconn.CommandTag{}, err
}
return c.execPrepared(ctx, sd, arguments)
case QueryExecModeExec:
return c.execSQLParams(ctx, sql, arguments)
case QueryExecModeSimpleProtocol:
return c.execSimpleProtocol(ctx, sql, arguments)
default:
return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
}
sd, err := c.Prepare(ctx, "", sql)
if err != nil {
return pgconn.CommandTag{}, err
}
return c.execPrepared(ctx, sd, arguments)
}
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
@ -510,6 +524,38 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
return result.CommandTag, result.Err
}
type unknownArgumentTypeQueryExecModeExecError struct {
arg interface{}
}
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
}
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) {
c.eqb.Reset()
anynil.NormalizeSlice(args)
paramOIDs := make([]uint32, len(args))
for i := range args {
dt, ok := c.TypeMap().TypeForValue(args[i])
if !ok {
return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
}
err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i])
if err != nil {
return pgconn.CommandTag{}, err
}
paramOIDs[i] = dt.OID
}
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
return result.CommandTag, result.Err
}
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
r := &connRows{}

View File

@ -256,7 +256,15 @@ func TestExecFailureWithArguments(t *testing.T) {
assert.False(t, pgconn.SafeToRetry(err))
_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
require.Error(t, err)
if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec {
// The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it
// locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check
// for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing
// the SQL.
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}