support *sql.Scanner for null handling

Fixes jackc/pgx#1211
pull/1281/head
James Hartig 2022-05-25 09:16:02 -04:00 committed by Jack Christensen
parent d846dbcb75
commit 824d8ad40d
2 changed files with 90 additions and 5 deletions

View File

@ -533,8 +533,22 @@ type scanPlanDataTypeSQLScanner DataType
func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
scanner, ok := dst.(sql.Scanner)
if !ok {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
if src == nil {
// Ensure the pointer points to a zero version of the value
dv.Elem().Set(reflect.Zero(dv.Type().Elem()))
return nil
}
dv = dv.Elem()
// If the pointer is to a nil pointer then set that before scanning
if dv.Kind() == reflect.Ptr && dv.IsNil() {
dv.Set(reflect.New(dv.Type().Elem()))
}
scanner = dv.Interface().(sql.Scanner)
}
dt := (*DataType)(plan)
@ -593,7 +607,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode
type scanPlanSQLScanner struct{}
func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
scanner := dst.(sql.Scanner)
scanner, ok := dst.(sql.Scanner)
if !ok {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
if src == nil {
// Ensure the pointer points to a zero version of the value
dv.Elem().Set(reflect.Zero(dv.Type()))
return nil
}
dv = dv.Elem()
// If the pointer is to a nil pointer then set that before scanning
if dv.Kind() == reflect.Ptr && dv.IsNil() {
dv.Set(reflect.New(dv.Type().Elem()))
}
scanner = dv.Interface().(sql.Scanner)
}
if src == nil {
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
// text format path would be converted to empty string.
@ -761,6 +793,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func isScanner(dst interface{}) bool {
if _, ok := dst.(sql.Scanner); ok {
return true
}
if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) {
return true
}
return false
}
// PlanScan prepares a plan to scan a value into dst.
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
switch formatCode {
@ -825,13 +869,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan
}
if dt != nil {
if _, ok := dst.(sql.Scanner); ok {
if isScanner(dst) {
return (*scanPlanDataTypeSQLScanner)(dt)
}
return (*scanPlanDataTypeAssignTo)(dt)
}
if _, ok := dst.(sql.Scanner); ok {
if isScanner(dst) {
return scanPlanSQLScanner{}
}

View File

@ -310,3 +310,44 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
}
}
}
type pgCustomInt int64
func (ci *pgCustomInt) Scan(src interface{}) error {
*ci = pgCustomInt(src.(int64))
return nil
}
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
ci := pgtype.NewConnInfo()
src := []byte{0, 42}
var v pgCustomInt
plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &v)
require.NoError(t, err)
require.EqualValues(t, 42, v)
ptr := new(pgCustomInt)
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr)
require.NoError(t, err)
require.EqualValues(t, 42, *ptr)
ptr = new(pgCustomInt)
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr)
require.NoError(t, err)
assert.Nil(t, ptr)
ptr = nil
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr)
require.NoError(t, err)
require.EqualValues(t, 42, *ptr)
ptr = nil
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr)
require.NoError(t, err)
assert.Nil(t, ptr)
}