Simplify encoding extended query arguments

query-exec-mode
Jack Christensen 2022-03-05 20:16:57 -06:00
parent 1cef9075d9
commit e5685a34fc
3 changed files with 65 additions and 39 deletions

View File

@ -1,9 +1,7 @@
package pgx
import (
"fmt"
"reflect"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/pgtype"
)
@ -55,14 +53,7 @@ func (eqb *extendedQueryBuilder) Reset() {
}
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
if arg == nil {
return nil, nil
}
refVal := reflect.ValueOf(arg)
argIsPtr := refVal.Kind() == reflect.Ptr
if argIsPtr && refVal.IsNil() {
if anynil.Is(arg) {
return nil, nil
}
@ -72,33 +63,15 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin
pos := len(eqb.paramValueBytes)
if arg, ok := arg.(string); ok {
return []byte(arg), nil
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if argIsPtr {
// We have already checked that arg is not pointing to nil,
// so it is safe to dereference here.
arg = refVal.Elem().Interface()
return eqb.encodeExtendedParamValue(m, oid, formatCode, arg)
if buf == nil {
return nil, nil
}
if _, ok := m.TypeForOID(oid); ok {
buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return eqb.encodeExtendedParamValue(m, oid, formatCode, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
eqb.paramValueBytes = buf
return eqb.paramValueBytes[pos:], nil
}
// chooseParameterFormatCode determines the correct format code for an

View File

@ -16,13 +16,37 @@ func (JSONCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan {
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan {
switch value.(type) {
case string:
return encodePlanJSONCodecEitherFormatString{}
case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{}
default:
return encodePlanJSONCodecEitherFormatMarshal{}
}
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
// appropriate wrappers here.
for _, f := range []TryWrapEncodePlanFunc{
TryWrapDerefPointerEncodePlan,
TryWrapFindUnderlyingTypeEncodePlan,
} {
if wrapperPlan, nextValue, ok := f(value); ok {
if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
}
}
return encodePlanJSONCodecEitherFormatMarshal{}
}
type encodePlanJSONCodecEitherFormatString struct{}
func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
jsonString := value.(string)
buf = append(buf, jsonString...)
return buf, nil
}
type encodePlanJSONCodecEitherFormatByteSlice struct{}

View File

@ -1155,6 +1155,14 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src
// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned.
func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan {
if format == TextFormatCode {
switch value.(type) {
case string:
return encodePlanStringToAnyTextFormat{}
case TextValuer:
return encodePlanTextValuerToAnyTextFormat{}
}
}
var dt *Type
@ -1187,6 +1195,27 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan
return nil
}
type encodePlanStringToAnyTextFormat struct{}
func (encodePlanStringToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
s := value.(string)
return append(buf, s...), nil
}
type encodePlanTextValuerToAnyTextFormat struct{}
func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
t, err := value.(TextValuer).TextValue()
if err != nil {
return nil, err
}
if !t.Valid {
return nil, nil
}
return append(buf, t.String...), nil
}
// TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan
// that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted
// by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it