diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index db916220..ce06e738 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1808,3 +1808,41 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu return newBuf, nil } + +// SQLScanner returns a database/sql.Scanner for v. This is necessary for types like Array[T] and Range[T] where the +// type needs assistance from Map to implement the sql.Scanner interface. It is not necessary for types like Box that +// implement sql.Scanner directly. +// +// This uses the type of v to look up the PostgreSQL OID that v presumably came from. This means v must be registered +// with m by calling RegisterDefaultPgType. +func (m *Map) SQLScanner(v any) sql.Scanner { + if s, ok := v.(sql.Scanner); ok { + return s + } + + return &sqlScannerWrapper{m: m, v: v} +} + +type sqlScannerWrapper struct { + m *Map + v any +} + +func (w *sqlScannerWrapper) Scan(src any) error { + t, ok := w.m.TypeForValue(w.v) + if !ok { + return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %T", w.v) + } + + var bufSrc []byte + switch src := src.(type) { + case string: + bufSrc = []byte(src) + case []byte: + bufSrc = src + default: + bufSrc = []byte(fmt.Sprint(bufSrc)) + } + + return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 61fb77d3..e4c53ea7 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -2,50 +2,58 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } // // Or from a DSN string. // -// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } // // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // -// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) -// connConfig.Logger = myLogger -// connStr := stdlib.RegisterConnConfig(connConfig) -// db, _ := sql.Open("pgx", connStr) +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // -// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. -// It does not support named parameters. +// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // -// db.QueryRow("select * from users where id=$1", userID) +// db.QueryRow("select * from users where id=$1", userID) // -// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard -// database/sql.DB connection pool. This allows operations that use pgx specific functionality. +// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection +// pool. This allows operations that use pgx specific functionality. // -// // Given db is a *sql.DB -// conn, err := db.Conn(context.Background()) -// if err != nil { -// // handle error from acquiring connection from DB pool -// } +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) +// if err != nil { +// // handle error from acquiring connection from DB pool +// } // -// err = conn.Raw(func(driverConn any) error { -// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn -// // Do pgx specific stuff with conn -// conn.CopyFrom(...) -// return nil -// }) -// if err != nil { -// // handle error that occurred while using *pgx.Conn -// } +// err = conn.Raw(func(driverConn any) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } +// +// PostgreSQL Specific Data Types +// +// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes +// these types usable as a sql.Scanner. +// +// m := pgtype.NewMap() +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) package stdlib import ( diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 78b2d01f..75f0caf4 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -373,6 +374,37 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { }) } +func TestConnQueryScanArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a []int64 + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, []int64{1, 2, 3}, a) + }) +} + +func TestConnQueryScanRange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var r pgtype.Range[pgtype.Int4] + err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r)) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r) + }) +} + // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) {