Query supports QueryExecMode

Fixed QueryExecModeExec as it must only use text format without
specifying param OIDs.
pull/1170/head
Jack Christensen 2022-03-12 14:15:39 -06:00
parent 0c166c7620
commit 1390a11fe2
4 changed files with 162 additions and 96 deletions

227
conn.go
View File

@ -98,6 +98,9 @@ var ErrNoRows = errors.New("no rows in result set")
// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
var ErrInvalidLogLevel = errors.New("invalid log level")
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
func Connect(ctx context.Context, connString string) (*Conn, error) {
@ -430,7 +433,7 @@ optionLoop:
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
return pgconn.CommandTag{}, errDisabledStatementCache
}
sd, err := c.statementCache.Get(ctx, sql)
if err != nil {
@ -440,7 +443,7 @@ optionLoop:
return c.execPrepared(ctx, sd, arguments)
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
return pgconn.CommandTag{}, errDisabledDescriptionCache
}
sd, err := c.descriptionCache.Get(ctx, sql)
if err != nil {
@ -536,26 +539,51 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}
c.eqb.Reset()
anynil.NormalizeSlice(args)
paramOIDs := make([]uint32, len(args))
for i := range args {
dt, ok := c.TypeMap().TypeForValue(args[i])
if !ok {
return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
}
err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i])
if err != nil {
return pgconn.CommandTag{}, err
}
paramOIDs[i] = dt.OID
err := c.appendParamsForQueryExecModeExec(args)
if err != nil {
return pgconn.CommandTag{}, err
}
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read()
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
return result.CommandTag, result.Err
}
// appendParamsForQueryExecModeExec appends the args to c.eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error {
for i := range args {
if args[i] == nil {
err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, args[i])
if err != nil {
return err
}
} else {
dt, ok := c.TypeMap().TypeForValue(args[i])
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
}
err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, args[i])
if err != nil {
return err
}
}
}
return nil
}
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
r := &connRows{}
@ -589,14 +617,11 @@ const (
// when the the database schema is modified concurrently.
QueryExecModeDescribeExec
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended
// protocol. Queries are executed in a single round trip. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g.
// "SELECT $1::boolean".
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
QueryExecModeExec
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
@ -605,8 +630,13 @@ const (
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to
// specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean".
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
//
// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
// not support the extended protocol.
QueryExecModeSimpleProtocol
)
@ -640,13 +670,13 @@ type QueryResultFormatsByOID map[uint32]int16
// Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully
// as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row.
//
// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
mode := c.config.DefaultQueryExecMode
optionLoop:
for len(args) > 0 {
@ -658,19 +688,97 @@ optionLoop:
resultFormatsByOID = arg
args = args[1:]
case QueryExecMode:
simpleProtocol = arg == QueryExecModeSimpleProtocol
mode = arg
args = args[1:]
default:
break optionLoop
}
}
c.eqb.Reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
var err error
sd, ok := c.preparedStatements[sql]
sd := c.preparedStatements[sql]
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
if sd == nil {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
err = errDisabledStatementCache
rows.fatal(err)
return rows, err
}
sd, err = c.statementCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
err = errDisabledDescriptionCache
rows.fatal(err)
return rows, err
}
sd, err = c.descriptionCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
}
}
if simpleProtocol && !ok {
if len(sd.ParamOIDs) != len(args) {
rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
return rows, rows.err
}
rows.sql = sd.SQL
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}
if resultFormatsByOID != nil {
resultFormats = make([]int16, len(sd.Fields))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
}
}
if resultFormats == nil {
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
resultFormats = c.eqb.resultFormats
}
if mode == QueryExecModeCacheDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
}
} else if mode == QueryExecModeExec {
err := c.appendParamsForQueryExecModeExec(args)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
} else if mode == QueryExecModeSimpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
@ -688,61 +796,10 @@ optionLoop:
}
return rows, nil
}
c.eqb.Reset()
if !ok {
if c.statementCache != nil {
sd, err = c.statementCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
} else {
sd, err = c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}
}
if len(sd.ParamOIDs) != len(args) {
rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
return rows, rows.err
}
rows.sql = sd.SQL
anynil.NormalizeSlice(args)
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}
if resultFormatsByOID != nil {
resultFormats = make([]int16, len(sd.Fields))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
}
}
if resultFormats == nil {
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
resultFormats = c.eqb.resultFormats
}
if c.statementCache != nil && c.statementCache.Mode() == stmtcache.ModeDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
err = fmt.Errorf("unknown QueryExecMode: %v", mode)
rows.fatal(err)
return rows, rows.err
}
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.

View File

@ -256,15 +256,7 @@ func TestExecFailureWithArguments(t *testing.T) {
assert.False(t, pgconn.SafeToRetry(err))
_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec {
// The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it
// locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check
// for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing
// the SQL.
require.NoError(t, err)
} else {
require.Error(t, err)
}
require.Error(t, err)
})
}

View File

@ -14,9 +14,13 @@ type extendedQueryBuilder struct {
func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error {
f := eqb.chooseParameterFormatCode(m, oid, arg)
eqb.paramFormats = append(eqb.paramFormats, f)
return eqb.AppendParamFormat(m, oid, f, arg)
}
v, err := eqb.encodeExtendedParamValue(m, oid, f, arg)
func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg interface{}) error {
eqb.paramFormats = append(eqb.paramFormats, format)
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}

View File

@ -891,6 +891,19 @@ func TestEncodeTypeRename(t *testing.T) {
inString := _string("foo")
var outString _string
// pgx.QueryExecModeExec requires all types to be registered.
conn.TypeMap().RegisterDefaultPgType(inInt, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt8, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt16, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt32, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt64, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint8, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint16, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint32, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint64, "int8")
conn.TypeMap().RegisterDefaultPgType(inString, "text")
err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)