From d4fcd4a897e44c0d416030dad4edac66ed01321f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 23 Dec 2022 14:15:45 -0600 Subject: [PATCH] Support sql.Scanner on renamed base type https://github.com/jackc/pgtype/issues/197 --- pgtype/pgtype.go | 14 ++++++++++---- pgtype/pgtype_test.go | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ad451678..7e8ef5ab 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1227,6 +1227,16 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { } } + // This needs to happen before trying m.TryWrapScanPlanFuncs. Otherwise, a sql.Scanner would not get called if it was + // defined on a type that could be unwrapped such as `type myString string`. + // + // https://github.com/jackc/pgtype/issues/197 + if dt == nil { + if _, ok := target.(sql.Scanner); ok { + return &scanPlanSQLScanner{formatCode: formatCode} + } + } + for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { @@ -1248,10 +1258,6 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { } } - if _, ok := target.(sql.Scanner); ok { - return &scanPlanSQLScanner{formatCode: formatCode} - } - return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index c0cb3852..5574b676 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -223,6 +223,23 @@ func TestMapScanUnknownOIDIntoSQLScanner(t *testing.T) { assert.False(t, s.Valid) } +type scannerString string + +func (ss *scannerString) Scan(v any) error { + *ss = scannerString("scanned") + return nil +} + +// https://github.com/jackc/pgtype/issues/197 +func TestMapScanUnregisteredOIDIntoRenamedStringSQLScanner(t *testing.T) { + m := pgtype.NewMap() + + var s scannerString + err := m.Scan(unregisteredOID, pgx.TextFormatCode, []byte(nil), &s) + assert.NoError(t, err) + assert.Equal(t, "scanned", string(s)) +} + type pgCustomInt int64 func (ci *pgCustomInt) Scan(src interface{}) error {