mirror of https://github.com/jackc/pgx.git
Enable all QueryExecModes for exec path
parent
8e341e20f3
commit
46966227bc
84
conn.go
84
conn.go
|
@ -405,48 +405,62 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
}
|
||||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
switch arg := arguments[0].(type) {
|
||||
case QueryExecMode:
|
||||
simpleProtocol = arg == QueryExecModeSimpleProtocol
|
||||
mode = arg
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
// Always use simple protocol when there are no arguments.
|
||||
if len(arguments) == 0 {
|
||||
mode = QueryExecModeSimpleProtocol
|
||||
}
|
||||
|
||||
if sd, ok := c.preparedStatements[sql]; ok {
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
}
|
||||
|
||||
if simpleProtocol {
|
||||
return c.execSimpleProtocol(ctx, sql, arguments)
|
||||
}
|
||||
|
||||
if len(arguments) == 0 {
|
||||
return c.execSimpleProtocol(ctx, sql, arguments)
|
||||
}
|
||||
|
||||
if c.statementCache != nil {
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
}
|
||||
sd, err := c.statementCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
if c.statementCache.Mode() == stmtcache.ModeDescribe {
|
||||
return c.execParams(ctx, sd, arguments)
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
}
|
||||
sd, err := c.descriptionCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
return c.execParams(ctx, sd, arguments)
|
||||
case QueryExecModeDescribeExec:
|
||||
sd, err := c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
case QueryExecModeExec:
|
||||
return c.execSQLParams(ctx, sql, arguments)
|
||||
case QueryExecModeSimpleProtocol:
|
||||
return c.execSimpleProtocol(ctx, sql, arguments)
|
||||
default:
|
||||
return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
|
||||
}
|
||||
|
||||
sd, err := c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
}
|
||||
|
||||
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
|
@ -510,6 +524,38 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
|
|||
return result.CommandTag, result.Err
|
||||
}
|
||||
|
||||
type unknownArgumentTypeQueryExecModeExecError struct {
|
||||
arg interface{}
|
||||
}
|
||||
|
||||
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
||||
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
||||
}
|
||||
|
||||
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}) (pgconn.CommandTag, error) {
|
||||
c.eqb.Reset()
|
||||
|
||||
anynil.NormalizeSlice(args)
|
||||
|
||||
paramOIDs := make([]uint32, len(args))
|
||||
|
||||
for i := range args {
|
||||
dt, ok := c.TypeMap().TypeForValue(args[i])
|
||||
if !ok {
|
||||
return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
|
||||
}
|
||||
err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i])
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
paramOIDs[i] = dt.OID
|
||||
}
|
||||
|
||||
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
|
||||
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
||||
return result.CommandTag, result.Err
|
||||
}
|
||||
|
||||
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
|
||||
r := &connRows{}
|
||||
|
||||
|
|
10
conn_test.go
10
conn_test.go
|
@ -256,7 +256,15 @@ func TestExecFailureWithArguments(t *testing.T) {
|
|||
assert.False(t, pgconn.SafeToRetry(err))
|
||||
|
||||
_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
|
||||
require.Error(t, err)
|
||||
if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec {
|
||||
// The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it
|
||||
// locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check
|
||||
// for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing
|
||||
// the SQL.
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue