mirror of https://github.com/jackc/pgx.git
fix: #2146
[](https://www.meetup.com/it-IT/Open-Source-Saturday-Milano/) Co-authored-by: Alessio Izzo <alessio.izzo86@gmail.com>pull/2151/head
parent
2ec900454b
commit
5c9b565116
|
@ -143,10 +143,12 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
|
||||||
case BytesScanner:
|
case BytesScanner:
|
||||||
return scanPlanBinaryBytesToBytesScanner{}
|
return scanPlanBinaryBytesToBytesScanner{}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
|
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/1418
|
// https://github.com/jackc/pgx/issues/1418
|
||||||
case sql.Scanner:
|
if isSQLScanner(target) {
|
||||||
return &scanPlanSQLScanner{formatCode: format}
|
return &scanPlanSQLScanner{formatCode: format}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,6 +157,20 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/2146
|
||||||
|
func isSQLScanner(v any) bool {
|
||||||
|
val := reflect.ValueOf(v)
|
||||||
|
for val.Kind() == reflect.Ptr {
|
||||||
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type scanPlanAnyToString struct{}
|
type scanPlanAnyToString struct{}
|
||||||
|
|
||||||
func (scanPlanAnyToString) Scan(src []byte, dst any) error {
|
func (scanPlanAnyToString) Scan(src []byte, dst any) error {
|
||||||
|
|
|
@ -63,6 +63,8 @@ func TestJSONCodec(t *testing.T) {
|
||||||
|
|
||||||
// Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805)
|
// Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805)
|
||||||
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
||||||
|
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
||||||
|
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
||||||
})
|
})
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -109,6 +111,31 @@ func (i Issue1805) MarshalJSON() ([]byte, error) {
|
||||||
return nil, errors.New("MarshalJSON called")
|
return nil, errors.New("MarshalJSON called")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Issue2146 int
|
||||||
|
|
||||||
|
func (i *Issue2146) Scan(src any) error {
|
||||||
|
var source []byte
|
||||||
|
switch src.(type) {
|
||||||
|
case string:
|
||||||
|
source = []byte(src.(string))
|
||||||
|
case []byte:
|
||||||
|
source = src.([]byte)
|
||||||
|
default:
|
||||||
|
return errors.New("unknown source type")
|
||||||
|
}
|
||||||
|
var newI int
|
||||||
|
if err := json.Unmarshal(source, &newI); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = Issue2146(newI + 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i Issue2146) Value() (driver.Value, error) {
|
||||||
|
b, err := json.Marshal(int(i - 1))
|
||||||
|
return string(b), err
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
||||||
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
|
@ -396,7 +396,12 @@ type scanPlanSQLScanner struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
scanner := dst.(sql.Scanner)
|
scanner := getSQLScanner(dst)
|
||||||
|
|
||||||
|
if scanner == nil {
|
||||||
|
return fmt.Errorf("cannot scan into %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
if src == nil {
|
if src == nil {
|
||||||
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
||||||
// text format path would be converted to empty string.
|
// text format path would be converted to empty string.
|
||||||
|
@ -408,6 +413,21 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
||||||
|
func getSQLScanner(target any) sql.Scanner {
|
||||||
|
val := reflect.ValueOf(target)
|
||||||
|
for val.Kind() == reflect.Ptr {
|
||||||
|
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||||
|
if val.IsNil() {
|
||||||
|
val.Set(reflect.New(val.Type().Elem()))
|
||||||
|
}
|
||||||
|
return val.Interface().(sql.Scanner)
|
||||||
|
}
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type scanPlanString struct{}
|
type scanPlanString struct{}
|
||||||
|
|
||||||
func (scanPlanString) Scan(src []byte, dst any) error {
|
func (scanPlanString) Scan(src []byte, dst any) error {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -631,3 +632,10 @@ func isExpectedEq(a any) func(any) bool {
|
||||||
return a == v
|
return a == v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPtrExpectedEq(a any) func(any) bool {
|
||||||
|
return func(v any) bool {
|
||||||
|
val := reflect.ValueOf(v)
|
||||||
|
return a == val.Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue