stdlib matches native pgx scanning support

stdlib can now directly scan into anything pgx can scan such as Go
slices.

This requires the change to database/sql implemented by
https://github.com/golang/go/pull/67648.

If this PR is accepted it will most likely land in Go 1.24.
pull/2029/head
Jack Christensen 2024-05-25 10:48:18 -05:00
parent 24c0a5e8ff
commit cbc7e95690
3 changed files with 212 additions and 11 deletions

View File

@ -8,6 +8,8 @@ import (
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/pgtype"
)
func getSelectRowsCounts(b *testing.B) []int64 {
@ -107,3 +109,52 @@ func BenchmarkSelectRowsScanNull(b *testing.B) {
})
}
}
func BenchmarkFlatArrayEncodeArgument(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)
input := make(pgtype.FlatArray[string], 10)
for i := range input {
input[i] = fmt.Sprintf("String %d", i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
var n int64
err := db.QueryRow("select cardinality($1::text[])", input).Scan(&n)
if err != nil {
b.Fatal(err)
}
if n != int64(len(input)) {
b.Fatalf("Expected %d, got %d", len(input), n)
}
}
}
func BenchmarkFlatArrayScanResult(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)
var input string
for i := 0; i < 10; i++ {
if i > 0 {
input += ","
}
input += fmt.Sprintf(`'String %d'`, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
var result pgtype.FlatArray[string]
err := db.QueryRow(fmt.Sprintf("select array[%s]::text[]", input)).Scan(&result)
if err != nil {
b.Fatal(err)
}
if len(result) != 10 {
b.Fatalf("Expected %d, got %d", len(result), 10)
}
}
}

View File

@ -847,6 +847,12 @@ func (r *Rows) Next(dest []driver.Value) error {
return nil
}
func (r *Rows) ScanColumn(index int, dest any) error {
m := r.conn.conn.TypeMap()
fd := r.rows.FieldDescriptions()[index]
return m.Scan(fd.DataTypeOID, fd.Format, r.rows.RawValues()[index], dest)
}
func valueToInterface(argsV []driver.Value) []any {
args := make([]any, 0, len(argsV))
for _, v := range argsV {

View File

@ -107,6 +107,32 @@ func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
}
}
func testWithKnownOIDQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
for _, mode := range []pgx.QueryExecMode{
pgx.QueryExecModeCacheStatement,
pgx.QueryExecModeCacheDescribe,
pgx.QueryExecModeDescribeExec,
} {
t.Run(mode.String(),
func(t *testing.T) {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.DefaultQueryExecMode = mode
db := stdlib.OpenDB(*config)
defer func() {
err := db.Close()
require.NoError(t, err)
}()
f(t, db)
ensureDBValid(t, db)
},
)
}
}
// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
// cover broken connections.
func ensureDBValid(t testing.TB, db *sql.DB) {
@ -509,29 +535,99 @@ func TestConnQueryScanGoArray(t *testing.T) {
})
}
func TestConnQueryScanArray(t *testing.T) {
func TestGoArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()
var names []string
var a pgtype.Array[int64]
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
require.Equal(t, []string{"John", "Jane"}, names)
err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
require.EqualValues(t, 2, n)
err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}
func TestConnQueryScanRange(t *testing.T) {
func TestGoArrayOfDriverValuer(t *testing.T) {
// Because []sql.NullString is not a registered type on the connection, it will only work with known OIDs.
testWithKnownOIDQueryExecModes(t, func(t *testing.T, db *sql.DB) {
var names []sql.NullString
err := db.QueryRow("select array['John', null, 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
require.Equal(t, []sql.NullString{{String: "John", Valid: true}, {}, {String: "Jane", Valid: true}}, names)
var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}
func TestPGTypeFlatArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
var names pgtype.FlatArray[string]
err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
require.Equal(t, pgtype.FlatArray[string]{"John", "Jane"}, names)
var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}
func TestPGTypeArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support nested arrays")
var matrix pgtype.Array[int64]
err := db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[]").Scan(&matrix)
require.NoError(t, err)
require.Equal(t,
pgtype.Array[int64]{
Elements: []int64{1, 2, 3, 4, 5, 6},
Dims: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 1},
{Length: 3, LowerBound: 1},
},
Valid: true},
matrix)
var equal bool
err = db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[] = $1::bigint[]", matrix).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)
err = db.QueryRow("select null::bigint[]").Scan(&matrix)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, matrix)
})
}
func TestConnQueryPGTypeRange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")
m := pgtype.NewMap()
var r pgtype.Range[pgtype.Int4]
err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
err := db.QueryRow("select int4range(1, 5)").Scan(&r)
require.NoError(t, err)
assert.Equal(
t,
@ -543,6 +639,54 @@ func TestConnQueryScanRange(t *testing.T) {
Valid: true,
},
r)
var equal bool
err = db.QueryRow("select int4range(1, 5) = $1::int4range", r).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)
err = db.QueryRow("select null::int4range").Scan(&r)
require.NoError(t, err)
assert.Equal(t, pgtype.Range[pgtype.Int4]{}, r)
})
}
func TestConnQueryPGTypeMultirange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")
skipPostgreSQLVersionLessThan(t, db, 14)
var r pgtype.Multirange[pgtype.Range[pgtype.Int4]]
err := db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9))").Scan(&r)
require.NoError(t, err)
assert.Equal(
t,
pgtype.Multirange[pgtype.Range[pgtype.Int4]]{
{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
{
Lower: pgtype.Int4{Int32: 7, Valid: true},
Upper: pgtype.Int4{Int32: 9, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
},
r)
var equal bool
err = db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9)) = $1::int4multirange", r).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)
err = db.QueryRow("select null::int4multirange").Scan(&r)
require.NoError(t, err)
require.Nil(t, r)
})
}