diff --git a/extended_query_builder.go b/extended_query_builder.go index 5d03790e..5409c0fd 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,9 +1,7 @@ package pgx import ( - "fmt" - "reflect" - + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/pgtype" ) @@ -55,14 +53,7 @@ func (eqb *extendedQueryBuilder) Reset() { } func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { - if arg == nil { - return nil, nil - } - - refVal := reflect.ValueOf(arg) - argIsPtr := refVal.Kind() == reflect.Ptr - - if argIsPtr && refVal.IsNil() { + if anynil.Is(arg) { return nil, nil } @@ -72,33 +63,15 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin pos := len(eqb.paramValueBytes) - if arg, ok := arg.(string); ok { - return []byte(arg), nil + buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) + if err != nil { + return nil, err } - - if argIsPtr { - // We have already checked that arg is not pointing to nil, - // so it is safe to dereference here. - arg = refVal.Elem().Interface() - return eqb.encodeExtendedParamValue(m, oid, formatCode, arg) + if buf == nil { + return nil, nil } - - if _, ok := m.TypeForOID(oid); ok { - buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return eqb.encodeExtendedParamValue(m, oid, formatCode, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil } // chooseParameterFormatCode determines the correct format code for an diff --git a/pgtype/json.go b/pgtype/json.go index e8882d3a..4d8cf4c4 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -16,13 +16,37 @@ func (JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { +func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan { switch value.(type) { + case string: + return encodePlanJSONCodecEitherFormatString{} case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} - default: - return encodePlanJSONCodecEitherFormatMarshal{} } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return encodePlanJSONCodecEitherFormatMarshal{} +} + +type encodePlanJSONCodecEitherFormatString struct{} + +func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + jsonString := value.(string) + buf = append(buf, jsonString...) + return buf, nil } type encodePlanJSONCodecEitherFormatByteSlice struct{} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d1a92089..75934ced 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1155,6 +1155,14 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { + if format == TextFormatCode { + switch value.(type) { + case string: + return encodePlanStringToAnyTextFormat{} + case TextValuer: + return encodePlanTextValuerToAnyTextFormat{} + } + } var dt *Type @@ -1187,6 +1195,27 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan return nil } +type encodePlanStringToAnyTextFormat struct{} + +func (encodePlanStringToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + s := value.(string) + return append(buf, s...), nil +} + +type encodePlanTextValuerToAnyTextFormat struct{} + +func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + return append(buf, t.String...), nil +} + // TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it