diff --git a/pgtype/json.go b/pgtype/json.go index 48b9f977..76cec51b 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP // // https://github.com/jackc/pgx/issues/2146 func isSQLScanner(v any) bool { + if _, is := v.(sql.Scanner); is { + return true + } + val := reflect.ValueOf(v) for val.Kind() == reflect.Ptr { if _, ok := val.Interface().(sql.Scanner); ok { @@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - elem := reflect.ValueOf(dst).Elem() + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() elem.Set(reflect.Zero(elem.Type())) return s.unmarshal(src, dst) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 18ca5a8e..6277fc8b 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -48,6 +48,7 @@ func TestJSONCodec(t *testing.T) { Age int `json:"age"` } + var str string pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, @@ -65,6 +66,9 @@ func TestJSONCodec(t *testing.T) { {Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))}, // Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146) {Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))}, + + // Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204) + {NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }}, }) pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ @@ -136,6 +140,27 @@ func (i Issue2146) Value() (driver.Value, error) { return string(b), err } +type NonPointerJSONScanner struct { + V *string +} + +func (i NonPointerJSONScanner) Scan(src any) error { + switch c := src.(type) { + case string: + *i.V = c + case []byte: + *i.V = string(c) + default: + return errors.New("unknown source type") + } + + return nil +} + +func (i NonPointerJSONScanner) Value() (driver.Value, error) { + return i.V, nil +} + // https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 func TestJSONCodecUnmarshalSQLNull(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -267,7 +292,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) { Unmarshal: func(data []byte, v any) error { return json.Unmarshal([]byte(`{"custom":"value"}`), v) }, - }}) + }, + }) } pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ @@ -278,3 +304,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) { }}, }) } + +func TestJSONCodecScanToNonPointerValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + n := 44 + err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n) + require.Error(t, err) + + var i *int + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i) + require.Error(t, err) + + m := 0 + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m) + require.NoError(t, err) + require.Equal(t, 42, m) + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f9d43edd..20645d69 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { // we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively func getSQLScanner(target any) sql.Scanner { + if sc, is := target.(sql.Scanner); is { + return sc + } + val := reflect.ValueOf(target) for val.Kind() == reflect.Ptr { if _, ok := val.Interface().(sql.Scanner); ok {