pull/2213/head
Kostas Stamatakis 2024-12-30 22:43:04 +02:00
parent bcf3fbd780
commit 6e9fa42fef
No known key found for this signature in database
GPG Key ID: A0A88AF285F7E69B
6 changed files with 133 additions and 2 deletions

15
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${fileDirname}"
}
]
}

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"go.testEnvVars": {
"PGX_TEST_DATABASE":"host=127.0.0.1 user=gamerhound password=gamerhound dbname=gamerhound"
}
}

View File

@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
//
// https://github.com/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
if _, is := v.(sql.Scanner); is {
return true
}
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
return fmt.Errorf("cannot scan NULL into %T", dst)
}
elem := reflect.ValueOf(dst).Elem()
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Pointer || v.IsNil() {
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
}
elem := v.Elem()
elem.Set(reflect.Zero(elem.Type()))
return s.unmarshal(src, dst)

View File

@ -267,7 +267,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
},
})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@ -278,3 +279,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
}},
})
}
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
n := 44
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
require.Error(t, err)
var i *int
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
require.Error(t, err)
m := 0
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
require.NoError(t, err)
require.Equal(t, 42, m)
})
}

View File

@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
if sc, is := target.(sql.Scanner); is {
return sc
}
val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {

80
tete/main.go Normal file
View File

@ -0,0 +1,80 @@
package main
import (
"context"
"fmt"
"log"
"reflect"
"github.com/jackc/pgx/v5/pgxpool"
)
func main() {
pool, err := pgxpool.New(context.Background(), "postgres://gamerhound:gamerhound@localhost:5432/gamerhound")
if err != nil {
log.Fatal(err)
}
defer pool.Close()
// Create the enum type.
_, err = pool.Exec(context.Background(), `DROP TYPE IF EXISTS test_enum_type`)
if err != nil {
log.Print(err)
return
}
_, err = pool.Exec(context.Background(), `CREATE TYPE test_enum_type AS ENUM ('a', 'b')`)
if err != nil {
log.Print(err)
return
}
err = testQuery(pool, "SELECT 'a'", "a")
if err != nil {
log.Printf("test TEXT error: %s\n", err)
}
err = testQuery(pool, "SELECT 'a'::test_enum_type", "a")
if err != nil {
log.Printf("test ENUM error: %s\n", err)
}
err = testQuery(pool, "SELECT '{}'::jsonb", "{}")
if err != nil {
log.Printf("test JSONB error: %s\n", err)
}
}
// T implements the sql.Scanner interface.
type T struct {
v *any
}
func (t T) Scan(v any) error {
*t.v = v
return nil
}
// testQuery executes the query and checks if the scanned value matches
// the expected result.
func testQuery(pool *pgxpool.Pool, query string, expected any) error {
rows, err := pool.Query(context.Background(), query)
if err != nil {
return err
}
// defer rows.Close()
var got any
t := T{v: &got}
for rows.Next() {
if err := rows.Scan(t); err != nil {
return err
}
}
if err = rows.Err(); err != nil {
return err
}
if !reflect.DeepEqual(got, expected) {
return fmt.Errorf("expected %#v; got %#v", expected, got)
}
return nil
}