mirror of https://github.com/jackc/pgx.git
fix #2204
parent
bcf3fbd780
commit
6e9fa42fef
|
@ -0,0 +1,15 @@
|
||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Launch Package",
|
||||||
|
"type": "go",
|
||||||
|
"request": "launch",
|
||||||
|
"mode": "debug",
|
||||||
|
"program": "${fileDirname}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
"go.testEnvVars": {
|
||||||
|
"PGX_TEST_DATABASE":"host=127.0.0.1 user=gamerhound password=gamerhound dbname=gamerhound"
|
||||||
|
}
|
||||||
|
}
|
|
@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/2146
|
// https://github.com/jackc/pgx/issues/2146
|
||||||
func isSQLScanner(v any) bool {
|
func isSQLScanner(v any) bool {
|
||||||
|
if _, is := v.(sql.Scanner); is {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
for val.Kind() == reflect.Ptr {
|
for val.Kind() == reflect.Ptr {
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||||
return fmt.Errorf("cannot scan NULL into %T", dst)
|
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
elem := reflect.ValueOf(dst).Elem()
|
v := reflect.ValueOf(dst)
|
||||||
|
if v.Kind() != reflect.Pointer || v.IsNil() {
|
||||||
|
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := v.Elem()
|
||||||
elem.Set(reflect.Zero(elem.Type()))
|
elem.Set(reflect.Zero(elem.Type()))
|
||||||
|
|
||||||
return s.unmarshal(src, dst)
|
return s.unmarshal(src, dst)
|
||||||
|
|
|
@ -267,7 +267,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
Unmarshal: func(data []byte, v any) error {
|
Unmarshal: func(data []byte, v any) error {
|
||||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -278,3 +279,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
n := 44
|
||||||
|
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var i *int
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
m := 0
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
|
|
||||||
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
||||||
func getSQLScanner(target any) sql.Scanner {
|
func getSQLScanner(target any) sql.Scanner {
|
||||||
|
if sc, is := target.(sql.Scanner); is {
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
|
||||||
val := reflect.ValueOf(target)
|
val := reflect.ValueOf(target)
|
||||||
for val.Kind() == reflect.Ptr {
|
for val.Kind() == reflect.Ptr {
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
pool, err := pgxpool.New(context.Background(), "postgres://gamerhound:gamerhound@localhost:5432/gamerhound")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
// Create the enum type.
|
||||||
|
_, err = pool.Exec(context.Background(), `DROP TYPE IF EXISTS test_enum_type`)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = pool.Exec(context.Background(), `CREATE TYPE test_enum_type AS ENUM ('a', 'b')`)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = testQuery(pool, "SELECT 'a'", "a")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("test TEXT error: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = testQuery(pool, "SELECT 'a'::test_enum_type", "a")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("test ENUM error: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = testQuery(pool, "SELECT '{}'::jsonb", "{}")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("test JSONB error: %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// T implements the sql.Scanner interface.
|
||||||
|
type T struct {
|
||||||
|
v *any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t T) Scan(v any) error {
|
||||||
|
*t.v = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testQuery executes the query and checks if the scanned value matches
|
||||||
|
// the expected result.
|
||||||
|
func testQuery(pool *pgxpool.Pool, query string, expected any) error {
|
||||||
|
rows, err := pool.Query(context.Background(), query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// defer rows.Close()
|
||||||
|
|
||||||
|
var got any
|
||||||
|
t := T{v: &got}
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(t); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, expected) {
|
||||||
|
return fmt.Errorf("expected %#v; got %#v", expected, got)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue