mirror of https://github.com/jackc/pgx.git
Use Map.Encode path for simple protocol
parent
c4b08378f2
commit
fe21cc7486
6
conn.go
6
conn.go
|
@ -555,6 +555,9 @@ const (
|
|||
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
|
||||
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
|
||||
// a map[string]string directly as an argument. This mode cannot.
|
||||
//
|
||||
// It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g.
|
||||
// "SELECT $1::boolean".
|
||||
QueryExecModeExec
|
||||
|
||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
|
||||
|
@ -562,6 +565,9 @@ const (
|
|||
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
|
||||
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
|
||||
// a map[string]string directly as an argument. This mode cannot.
|
||||
//
|
||||
// This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to
|
||||
// specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean".
|
||||
QueryExecModeSimpleProtocol
|
||||
)
|
||||
|
||||
|
|
|
@ -1418,7 +1418,7 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||
var actual bool
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
"select $1::boolean",
|
||||
pgx.QueryExecModeSimpleProtocol,
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
|
@ -1733,7 +1733,7 @@ func TestConnSimpleProtocol(t *testing.T) {
|
|||
var actualString string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
"select $1::int8, $2::float8, $3::boolean, $4::bytea, $5::text",
|
||||
pgx.QueryExecModeSimpleProtocol,
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
|
|
127
values.go
127
values.go
|
@ -1,12 +1,6 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
@ -18,88 +12,19 @@ const (
|
|||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
// SerializationError occurs on failure to encode or decode a value
|
||||
type SerializationError string
|
||||
|
||||
func (e SerializationError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) {
|
||||
if anynil.Is(arg) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if dv, ok := arg.(driver.Valuer); ok {
|
||||
return dv.Value()
|
||||
buf, err := m.Encode(0, TextFormatCode, arg, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change
|
||||
// the type of the argument. e.g. '42' instead of 42. So standard types are special cased.
|
||||
switch arg := arg.(type) {
|
||||
case float32:
|
||||
return float64(arg), nil
|
||||
case float64:
|
||||
return arg, nil
|
||||
case bool:
|
||||
return arg, nil
|
||||
case time.Duration:
|
||||
return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil
|
||||
case time.Time:
|
||||
return arg, nil
|
||||
case string:
|
||||
return arg, nil
|
||||
case []byte:
|
||||
return arg, nil
|
||||
case int8:
|
||||
return int64(arg), nil
|
||||
case int16:
|
||||
return int64(arg), nil
|
||||
case int32:
|
||||
return int64(arg), nil
|
||||
case int64:
|
||||
return arg, nil
|
||||
case int:
|
||||
return int64(arg), nil
|
||||
case uint8:
|
||||
return int64(arg), nil
|
||||
case uint16:
|
||||
return int64(arg), nil
|
||||
case uint32:
|
||||
return int64(arg), nil
|
||||
case uint64:
|
||||
if arg > math.MaxInt64 {
|
||||
return nil, fmt.Errorf("arg too big for int64: %v", arg)
|
||||
}
|
||||
return int64(arg), nil
|
||||
case uint:
|
||||
if uint64(arg) > math.MaxInt64 {
|
||||
return nil, fmt.Errorf("arg too big for int64: %v", arg)
|
||||
}
|
||||
return int64(arg), nil
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if _, found := m.TypeForValue(arg); found {
|
||||
buf, err := m.Encode(0, TextFormatCode, arg, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if buf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
arg = refVal.Elem().Interface()
|
||||
return convertSimpleArgument(m, arg)
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return convertSimpleArgument(m, strippedArg)
|
||||
}
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
|
||||
|
@ -119,43 +44,3 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]
|
|||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
||||
switch val.Kind() {
|
||||
case reflect.Int:
|
||||
convVal := int(val.Int())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Int8:
|
||||
convVal := int8(val.Int())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Int16:
|
||||
convVal := int16(val.Int())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Int32:
|
||||
convVal := int32(val.Int())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Int64:
|
||||
convVal := int64(val.Int())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Uint:
|
||||
convVal := uint(val.Uint())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Uint8:
|
||||
convVal := uint8(val.Uint())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Uint16:
|
||||
convVal := uint16(val.Uint())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Uint32:
|
||||
convVal := uint32(val.Uint())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.Uint64:
|
||||
convVal := uint64(val.Uint())
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
case reflect.String:
|
||||
convVal := val.String()
|
||||
return convVal, reflect.TypeOf(convVal) != val.Type()
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue