diff --git a/conn.go b/conn.go index 9f2fdcf0..34d23198 100644 --- a/conn.go +++ b/conn.go @@ -555,6 +555,9 @@ const ( // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. + // + // It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. + // "SELECT $1::boolean". QueryExecModeExec // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. @@ -562,6 +565,9 @@ const ( // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. + // + // This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to + // specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean". QueryExecModeSimpleProtocol ) diff --git a/query_test.go b/query_test.go index b6a0d65d..8ed89007 100644 --- a/query_test.go +++ b/query_test.go @@ -1418,7 +1418,7 @@ func TestConnSimpleProtocol(t *testing.T) { var actual bool err := conn.QueryRow( context.Background(), - "select $1", + "select $1::boolean", pgx.QueryExecModeSimpleProtocol, expected, ).Scan(&actual) @@ -1733,7 +1733,7 @@ func TestConnSimpleProtocol(t *testing.T) { var actualString string err := conn.QueryRow( context.Background(), - "select $1::int8, $2::float8, $3, $4::bytea, $5::text", + "select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text", pgx.QueryExecModeSimpleProtocol, expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) diff --git a/values.go b/values.go index a3343d81..595f2b4d 100644 --- a/values.go +++ b/values.go @@ -1,12 +1,6 @@ package pgx import ( - "database/sql/driver" - "fmt" - "math" - "reflect" - "time" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" @@ -18,88 +12,19 @@ const ( BinaryFormatCode = 1 ) -// SerializationError occurs on failure to encode or decode a value -type SerializationError string - -func (e SerializationError) Error() string { - return string(e) -} - func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) { if anynil.Is(arg) { return nil, nil } - if dv, ok := arg.(driver.Valuer); ok { - return dv.Value() + buf, err := m.Encode(0, TextFormatCode, arg, nil) + if err != nil { + return nil, err } - - // All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change - // the type of the argument. e.g. '42' instead of 42. So standard types are special cased. - switch arg := arg.(type) { - case float32: - return float64(arg), nil - case float64: - return arg, nil - case bool: - return arg, nil - case time.Duration: - return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil - case time.Time: - return arg, nil - case string: - return arg, nil - case []byte: - return arg, nil - case int8: - return int64(arg), nil - case int16: - return int64(arg), nil - case int32: - return int64(arg), nil - case int64: - return arg, nil - case int: - return int64(arg), nil - case uint8: - return int64(arg), nil - case uint16: - return int64(arg), nil - case uint32: - return int64(arg), nil - case uint64: - if arg > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil - case uint: - if uint64(arg) > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) - } - return int64(arg), nil + if buf == nil { + return nil, nil } - - if _, found := m.TypeForValue(arg); found { - buf, err := m.Encode(0, TextFormatCode, arg, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil - } - - refVal := reflect.ValueOf(arg) - if refVal.Kind() == reflect.Ptr { - arg = refVal.Elem().Interface() - return convertSimpleArgument(m, arg) - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return convertSimpleArgument(m, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) + return string(buf), nil } func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { @@ -119,43 +44,3 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([] } return buf, nil } - -func stripNamedType(val *reflect.Value) (interface{}, bool) { - switch val.Kind() { - case reflect.Int: - convVal := int(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int8: - convVal := int8(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int16: - convVal := int16(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int32: - convVal := int32(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Int64: - convVal := int64(val.Int()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint: - convVal := uint(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint8: - convVal := uint8(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint16: - convVal := uint16(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint32: - convVal := uint32(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.Uint64: - convVal := uint64(val.Uint()) - return convVal, reflect.TypeOf(convVal) != val.Type() - case reflect.String: - convVal := val.String() - return convVal, reflect.TypeOf(convVal) != val.Type() - } - - return nil, false -}