From fccaebc93dba7a54a636d4af400bd3676629c401 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 16 Apr 2022 13:38:27 -0500
Subject: [PATCH] Add pgtype.Map.SQLScanner

This enables compatibility with database/sql for types that cannot
implement Scan themselves.
---
 pgtype/pgtype.go   | 38 +++++++++++++++++++++++++
 stdlib/sql.go      | 70 ++++++++++++++++++++++++++--------------------
 stdlib/sql_test.go | 32 +++++++++++++++++++++
 3 files changed, 109 insertions(+), 31 deletions(-)

diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go
index db916220..ce06e738 100644
--- a/pgtype/pgtype.go
+++ b/pgtype/pgtype.go
@@ -1808,3 +1808,41 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu
 
 	return newBuf, nil
 }
+
+// SQLScanner returns a database/sql.Scanner for v. This is necessary for types like Array[T] and Range[T] where the
+// type needs assistance from Map to implement the sql.Scanner interface. It is not necessary for types like Box that
+// implement sql.Scanner directly.
+//
+// This uses the type of v to look up the PostgreSQL OID that v presumably came from. This means v must be registered
+// with m by calling RegisterDefaultPgType.
+func (m *Map) SQLScanner(v any) sql.Scanner {
+	if s, ok := v.(sql.Scanner); ok {
+		return s
+	}
+
+	return &sqlScannerWrapper{m: m, v: v}
+}
+
+type sqlScannerWrapper struct {
+	m *Map
+	v any
+}
+
+func (w *sqlScannerWrapper) Scan(src any) error {
+	t, ok := w.m.TypeForValue(w.v)
+	if !ok {
+		return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %T", w.v)
+	}
+
+	var bufSrc []byte
+	switch src := src.(type) {
+	case string:
+		bufSrc = []byte(src)
+	case []byte:
+		bufSrc = src
+	default:
+		bufSrc = []byte(fmt.Sprint(bufSrc))
+	}
+
+	return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v)
+}
diff --git a/stdlib/sql.go b/stdlib/sql.go
index 61fb77d3..e4c53ea7 100644
--- a/stdlib/sql.go
+++ b/stdlib/sql.go
@@ -2,50 +2,58 @@
 //
 // A database/sql connection can be established through sql.Open.
 //
-//	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
-//	if err != nil {
-//		return err
-//	}
+//  db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
+//  if err != nil {
+//    return err
+//  }
 //
 // Or from a DSN string.
 //
-//	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
-//	if err != nil {
-//		return err
-//	}
+//  db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
+//  if err != nil {
+//    return err
+//  }
 //
 // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
 // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
 // with sql.Open.
 //
-//	connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
-//	connConfig.Logger = myLogger
-//	connStr := stdlib.RegisterConnConfig(connConfig)
-//	db, _ := sql.Open("pgx", connStr)
+//  connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
+//  connConfig.Logger = myLogger
+//  connStr := stdlib.RegisterConnConfig(connConfig)
+//  db, _ := sql.Open("pgx", connStr)
 //
-// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
-// It does not support named parameters.
+// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters.
 //
-//	db.QueryRow("select * from users where id=$1", userID)
+//  db.QueryRow("select * from users where id=$1", userID)
 //
-// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard
-// database/sql.DB connection pool. This allows operations that use pgx specific functionality.
+// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection
+// pool. This allows operations that use pgx specific functionality.
 //
-//	// Given db is a *sql.DB
-//	conn, err := db.Conn(context.Background())
-//	if err != nil {
-//		// handle error from acquiring connection from DB pool
-//	}
+//  // Given db is a *sql.DB
+//  conn, err := db.Conn(context.Background())
+//  if err != nil {
+//    // handle error from acquiring connection from DB pool
+//  }
 //
-//	err = conn.Raw(func(driverConn any) error {
-//		conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
-//		// Do pgx specific stuff with conn
-//		conn.CopyFrom(...)
-//		return nil
-//	})
-//	if err != nil {
-//		// handle error that occurred while using *pgx.Conn
-//	}
+//  err = conn.Raw(func(driverConn any) error {
+//    conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
+//    // Do pgx specific stuff with conn
+//    conn.CopyFrom(...)
+//    return nil
+//  })
+//  if err != nil {
+//    // handle error that occurred while using *pgx.Conn
+//  }
+//
+// PostgreSQL Specific Data Types
+//
+// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes
+// these types usable as a sql.Scanner.
+//
+//  m := pgtype.NewMap()
+//  var a []int64
+//  err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
 package stdlib
 
 import (
diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go
index 78b2d01f..75f0caf4 100644
--- a/stdlib/sql_test.go
+++ b/stdlib/sql_test.go
@@ -15,6 +15,7 @@ import (
 
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgconn"
+	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/jackc/pgx/v5/stdlib"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -373,6 +374,37 @@ func TestConnSimpleSlicePassThrough(t *testing.T) {
 	})
 }
 
+func TestConnQueryScanArray(t *testing.T) {
+	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
+		m := pgtype.NewMap()
+
+		var a []int64
+		err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
+		require.NoError(t, err)
+		assert.Equal(t, []int64{1, 2, 3}, a)
+	})
+}
+
+func TestConnQueryScanRange(t *testing.T) {
+	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
+		m := pgtype.NewMap()
+
+		var r pgtype.Range[pgtype.Int4]
+		err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
+		require.NoError(t, err)
+		assert.Equal(
+			t,
+			pgtype.Range[pgtype.Int4]{
+				Lower:     pgtype.Int4{Int32: 1, Valid: true},
+				Upper:     pgtype.Int4{Int32: 5, Valid: true},
+				LowerType: pgtype.Inclusive,
+				UpperType: pgtype.Exclusive,
+				Valid:     true,
+			},
+			r)
+	})
+}
+
 // Test type that pgx would handle natively in binary, but since it is not a
 // database/sql native type should be passed through as a string
 func TestConnQueryRowPgxBinary(t *testing.T) {