Use Map.Encode path for simple protocol

query-exec-mode
Jack Christensen 2022-03-05 21:40:49 -06:00
parent c4b08378f2
commit fe21cc7486
3 changed files with 14 additions and 123 deletions

View File

@ -555,6 +555,9 @@ const (
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g.
// "SELECT $1::boolean".
QueryExecModeExec
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
@ -562,6 +565,9 @@ const (
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to
// specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean".
QueryExecModeSimpleProtocol
)

View File

@ -1418,7 +1418,7 @@ func TestConnSimpleProtocol(t *testing.T) {
var actual bool
err := conn.QueryRow(
context.Background(),
"select $1",
"select $1::boolean",
pgx.QueryExecModeSimpleProtocol,
expected,
).Scan(&actual)
@ -1733,7 +1733,7 @@ func TestConnSimpleProtocol(t *testing.T) {
var actualString string
err := conn.QueryRow(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
"select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text",
pgx.QueryExecModeSimpleProtocol,
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)

127
values.go
View File

@ -1,12 +1,6 @@
package pgx
import (
"database/sql/driver"
"fmt"
"math"
"reflect"
"time"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgtype"
@ -18,88 +12,19 @@ const (
BinaryFormatCode = 1
)
// SerializationError occurs on failure to encode or decode a value
type SerializationError string
func (e SerializationError) Error() string {
return string(e)
}
func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) {
if anynil.Is(arg) {
return nil, nil
}
if dv, ok := arg.(driver.Valuer); ok {
return dv.Value()
buf, err := m.Encode(0, TextFormatCode, arg, nil)
if err != nil {
return nil, err
}
// All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change
// the type of the argument. e.g. '42' instead of 42. So standard types are special cased.
switch arg := arg.(type) {
case float32:
return float64(arg), nil
case float64:
return arg, nil
case bool:
return arg, nil
case time.Duration:
return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil
case time.Time:
return arg, nil
case string:
return arg, nil
case []byte:
return arg, nil
case int8:
return int64(arg), nil
case int16:
return int64(arg), nil
case int32:
return int64(arg), nil
case int64:
return arg, nil
case int:
return int64(arg), nil
case uint8:
return int64(arg), nil
case uint16:
return int64(arg), nil
case uint32:
return int64(arg), nil
case uint64:
if arg > math.MaxInt64 {
return nil, fmt.Errorf("arg too big for int64: %v", arg)
}
return int64(arg), nil
case uint:
if uint64(arg) > math.MaxInt64 {
return nil, fmt.Errorf("arg too big for int64: %v", arg)
}
return int64(arg), nil
if buf == nil {
return nil, nil
}
if _, found := m.TypeForValue(arg); found {
buf, err := m.Encode(0, TextFormatCode, arg, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
arg = refVal.Elem().Interface()
return convertSimpleArgument(m, arg)
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return convertSimpleArgument(m, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
return string(buf), nil
}
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
@ -119,43 +44,3 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]
}
return buf, nil
}
func stripNamedType(val *reflect.Value) (interface{}, bool) {
switch val.Kind() {
case reflect.Int:
convVal := int(val.Int())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Int8:
convVal := int8(val.Int())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Int16:
convVal := int16(val.Int())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Int32:
convVal := int32(val.Int())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Int64:
convVal := int64(val.Int())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Uint:
convVal := uint(val.Uint())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Uint8:
convVal := uint8(val.Uint())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Uint16:
convVal := uint16(val.Uint())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Uint32:
convVal := uint32(val.Uint())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.Uint64:
convVal := uint64(val.Uint())
return convVal, reflect.TypeOf(convVal) != val.Type()
case reflect.String:
convVal := val.String()
return convVal, reflect.TypeOf(convVal) != val.Type()
}
return nil, false
}