mirror of https://github.com/jackc/pgx.git
commit
c2175fe46e
|
@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/2146
|
// https://github.com/jackc/pgx/issues/2146
|
||||||
func isSQLScanner(v any) bool {
|
func isSQLScanner(v any) bool {
|
||||||
|
if _, is := v.(sql.Scanner); is {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
for val.Kind() == reflect.Ptr {
|
for val.Kind() == reflect.Ptr {
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||||
return fmt.Errorf("cannot scan NULL into %T", dst)
|
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
elem := reflect.ValueOf(dst).Elem()
|
v := reflect.ValueOf(dst)
|
||||||
|
if v.Kind() != reflect.Pointer || v.IsNil() {
|
||||||
|
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := v.Elem()
|
||||||
elem.Set(reflect.Zero(elem.Type()))
|
elem.Set(reflect.Zero(elem.Type()))
|
||||||
|
|
||||||
return s.unmarshal(src, dst)
|
return s.unmarshal(src, dst)
|
||||||
|
|
|
@ -48,6 +48,7 @@ func TestJSONCodec(t *testing.T) {
|
||||||
Age int `json:"age"`
|
Age int `json:"age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var str string
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
|
||||||
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
|
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
|
||||||
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
||||||
|
@ -65,6 +66,9 @@ func TestJSONCodec(t *testing.T) {
|
||||||
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
||||||
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
||||||
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
||||||
|
|
||||||
|
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
|
||||||
|
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
|
||||||
})
|
})
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -136,6 +140,27 @@ func (i Issue2146) Value() (driver.Value, error) {
|
||||||
return string(b), err
|
return string(b), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NonPointerJSONScanner struct {
|
||||||
|
V *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i NonPointerJSONScanner) Scan(src any) error {
|
||||||
|
switch c := src.(type) {
|
||||||
|
case string:
|
||||||
|
*i.V = c
|
||||||
|
case []byte:
|
||||||
|
*i.V = string(c)
|
||||||
|
default:
|
||||||
|
return errors.New("unknown source type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
|
||||||
|
return i.V, nil
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
||||||
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
@ -267,7 +292,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
Unmarshal: func(data []byte, v any) error {
|
Unmarshal: func(data []byte, v any) error {
|
||||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -278,3 +304,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
n := 44
|
||||||
|
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var i *int
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
m := 0
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
|
|
||||||
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
||||||
func getSQLScanner(target any) sql.Scanner {
|
func getSQLScanner(target any) sql.Scanner {
|
||||||
|
if sc, is := target.(sql.Scanner); is {
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
|
||||||
val := reflect.ValueOf(target)
|
val := reflect.ValueOf(target)
|
||||||
for val.Kind() == reflect.Ptr {
|
for val.Kind() == reflect.Ptr {
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
|
Loading…
Reference in New Issue