From e5685a34fc7f120b3479c719ed96e5b2d92e9221 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 5 Mar 2022 20:16:57 -0600
Subject: [PATCH] Simplify encoding extended query arguments

---
 extended_query_builder.go | 45 ++++++++-------------------------------
 pgtype/json.go            | 30 +++++++++++++++++++++++---
 pgtype/pgtype.go          | 29 +++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 39 deletions(-)

diff --git a/extended_query_builder.go b/extended_query_builder.go
index 5d03790e..5409c0fd 100644
--- a/extended_query_builder.go
+++ b/extended_query_builder.go
@@ -1,9 +1,7 @@
 package pgx
 
 import (
-	"fmt"
-	"reflect"
-
+	"github.com/jackc/pgx/v5/internal/anynil"
 	"github.com/jackc/pgx/v5/pgtype"
 )
 
@@ -55,14 +53,7 @@ func (eqb *extendedQueryBuilder) Reset() {
 }
 
 func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
-	if arg == nil {
-		return nil, nil
-	}
-
-	refVal := reflect.ValueOf(arg)
-	argIsPtr := refVal.Kind() == reflect.Ptr
-
-	if argIsPtr && refVal.IsNil() {
+	if anynil.Is(arg) {
 		return nil, nil
 	}
 
@@ -72,33 +63,15 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin
 
 	pos := len(eqb.paramValueBytes)
 
-	if arg, ok := arg.(string); ok {
-		return []byte(arg), nil
+	buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
+	if err != nil {
+		return nil, err
 	}
-
-	if argIsPtr {
-		// We have already checked that arg is not pointing to nil,
-		// so it is safe to dereference here.
-		arg = refVal.Elem().Interface()
-		return eqb.encodeExtendedParamValue(m, oid, formatCode, arg)
+	if buf == nil {
+		return nil, nil
 	}
-
-	if _, ok := m.TypeForOID(oid); ok {
-		buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
-		if err != nil {
-			return nil, err
-		}
-		if buf == nil {
-			return nil, nil
-		}
-		eqb.paramValueBytes = buf
-		return eqb.paramValueBytes[pos:], nil
-	}
-
-	if strippedArg, ok := stripNamedType(&refVal); ok {
-		return eqb.encodeExtendedParamValue(m, oid, formatCode, strippedArg)
-	}
-	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
+	eqb.paramValueBytes = buf
+	return eqb.paramValueBytes[pos:], nil
 }
 
 // chooseParameterFormatCode determines the correct format code for an
diff --git a/pgtype/json.go b/pgtype/json.go
index e8882d3a..4d8cf4c4 100644
--- a/pgtype/json.go
+++ b/pgtype/json.go
@@ -16,13 +16,37 @@ func (JSONCodec) PreferredFormat() int16 {
 	return TextFormatCode
 }
 
-func (JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan {
+func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value interface{}) EncodePlan {
 	switch value.(type) {
+	case string:
+		return encodePlanJSONCodecEitherFormatString{}
 	case []byte:
 		return encodePlanJSONCodecEitherFormatByteSlice{}
-	default:
-		return encodePlanJSONCodecEitherFormatMarshal{}
 	}
+
+	// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
+	// appropriate wrappers here.
+	for _, f := range []TryWrapEncodePlanFunc{
+		TryWrapDerefPointerEncodePlan,
+		TryWrapFindUnderlyingTypeEncodePlan,
+	} {
+		if wrapperPlan, nextValue, ok := f(value); ok {
+			if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil {
+				wrapperPlan.SetNext(nextPlan)
+				return wrapperPlan
+			}
+		}
+	}
+
+	return encodePlanJSONCodecEitherFormatMarshal{}
+}
+
+type encodePlanJSONCodecEitherFormatString struct{}
+
+func (encodePlanJSONCodecEitherFormatString) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
+	jsonString := value.(string)
+	buf = append(buf, jsonString...)
+	return buf, nil
 }
 
 type encodePlanJSONCodecEitherFormatByteSlice struct{}
diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go
index d1a92089..75934ced 100644
--- a/pgtype/pgtype.go
+++ b/pgtype/pgtype.go
@@ -1155,6 +1155,14 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src
 // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be
 // found then nil is returned.
 func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan {
+	if format == TextFormatCode {
+		switch value.(type) {
+		case string:
+			return encodePlanStringToAnyTextFormat{}
+		case TextValuer:
+			return encodePlanTextValuerToAnyTextFormat{}
+		}
+	}
 
 	var dt *Type
 
@@ -1187,6 +1195,27 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan
 	return nil
 }
 
+type encodePlanStringToAnyTextFormat struct{}
+
+func (encodePlanStringToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
+	s := value.(string)
+	return append(buf, s...), nil
+}
+
+type encodePlanTextValuerToAnyTextFormat struct{}
+
+func (encodePlanTextValuerToAnyTextFormat) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
+	t, err := value.(TextValuer).TextValue()
+	if err != nil {
+		return nil, err
+	}
+	if !t.Valid {
+		return nil, nil
+	}
+
+	return append(buf, t.String...), nil
+}
+
 // TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan
 // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted
 // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it