From 0d5d8e013747f33b6798d474ca6604ce3e1f5615 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Mon, 22 Aug 2022 20:06:42 -0500
Subject: [PATCH] Fallback to other format when encoding query arguments

The preferred format may not be possible for certain arguments. For
example, the preferred format for numeric is binary. But if
shopspring/decimal is being used without jackc/pgx-shopspring-decimal
then it will use the database/sql/driver.Valuer interface. This will
return a string. That string should be sent in the text format.

A similar case occurs when encoding a []string into a non-text
PostgreSQL array such as uuid[].
---
 extended_query_builder.go  | 23 +++++++++++++++++++++--
 pgtype/array_codec_test.go | 33 +++++++++++++++++++++++++++++++++
 query_test.go              | 22 +++++++++++++++++++++-
 3 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/extended_query_builder.go b/extended_query_builder.go
index 1c47063c..b0c0e02b 100644
--- a/extended_query_builder.go
+++ b/extended_query_builder.go
@@ -51,14 +51,33 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
 // must be an untyped nil.
 func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
 	if format == -1 {
-		format = eqb.chooseParameterFormatCode(m, oid, arg)
+		preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
+		preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
+		if preferredErr == nil {
+			return nil
+		}
+
+		var otherFormat int16
+		if preferredFormat == TextFormatCode {
+			otherFormat = BinaryFormatCode
+		} else {
+			otherFormat = TextFormatCode
+		}
+
+		otherErr := eqb.appendParam(m, oid, otherFormat, arg)
+		if otherErr == nil {
+			return nil
+		}
+
+		return preferredErr // return the error from the preferred format
 	}
-	eqb.ParamFormats = append(eqb.ParamFormats, format)
 
 	v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
 	if err != nil {
 		return err
 	}
+
+	eqb.ParamFormats = append(eqb.ParamFormats, format)
 	eqb.ParamValues = append(eqb.ParamValues, v)
 
 	return nil
diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go
index 9da027e8..a558d0fc 100644
--- a/pgtype/array_codec_test.go
+++ b/pgtype/array_codec_test.go
@@ -2,6 +2,8 @@ package pgtype_test
 
 import (
 	"context"
+	"encoding/hex"
+	"strings"
 	"testing"
 
 	pgx "github.com/jackc/pgx/v5"
@@ -124,6 +126,37 @@ func TestArrayCodecAnySlice(t *testing.T) {
 	})
 }
 
+// https://github.com/jackc/pgx/issues/1273#issuecomment-1218262703
+func TestArrayCodecSliceArgConversion(t *testing.T) {
+	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		arg := []string{
+			"3ad95bfd-ecea-4032-83c3-0c823cafb372",
+			"951baf11-c0cc-4afc-a779-abff0611dbf1",
+			"8327f244-7e2f-45e7-a10b-fbdc9d6f3378",
+		}
+
+		var expected []pgtype.UUID
+
+		for _, s := range arg {
+			buf, err := hex.DecodeString(strings.ReplaceAll(s, "-", ""))
+			require.NoError(t, err)
+			var u pgtype.UUID
+			copy(u.Bytes[:], buf)
+			u.Valid = true
+			expected = append(expected, u)
+		}
+
+		var actual []pgtype.UUID
+		err := conn.QueryRow(
+			ctx,
+			"select $1::uuid[]",
+			arg,
+		).Scan(&actual)
+		require.NoError(t, err)
+		require.Equal(t, expected, actual)
+	})
+}
+
 func TestArrayCodecDecodeValue(t *testing.T) {
 	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
 		for _, tt := range []struct {
diff --git a/query_test.go b/query_test.go
index 720a1911..b2aa5d10 100644
--- a/query_test.go
+++ b/query_test.go
@@ -1165,7 +1165,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes
 	ensureConnValid(t, conn)
 }
 
-func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
+func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1181,6 +1181,26 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *
 	ensureConnValid(t, conn)
 }
 
+// https://github.com/jackc/pgx/issues/1273#issuecomment-1221672175
+func TestConnQueryDatabaseSQLDriverValuerTextWhenBinaryIsPreferred(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+
+	arg := sql.NullString{String: "1.234", Valid: true}
+	var result pgtype.Numeric
+	err := conn.QueryRow(context.Background(), "select $1::numeric", arg).Scan(&result)
+	require.NoError(t, err)
+
+	require.True(t, result.Valid)
+	f64, err := result.Float64Value()
+	require.NoError(t, err)
+	require.Equal(t, pgtype.Float8{Float64: 1.234, Valid: true}, f64)
+
+	ensureConnValid(t, conn)
+}
+
 func TestConnQueryDatabaseSQLNullX(t *testing.T) {
 	t.Parallel()