diff --git a/pgtype.go b/pgtype.go index eba09fa5..4078da7b 100644 --- a/pgtype.go +++ b/pgtype.go @@ -533,8 +533,22 @@ type scanPlanDataTypeSQLScanner DataType func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { scanner, ok := dst.(sql.Scanner) if !ok { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type().Elem())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) } dt := (*DataType)(plan) @@ -593,7 +607,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - scanner := dst.(sql.Scanner) + scanner, ok := dst.(sql.Scanner) + if !ok { + dv := reflect.ValueOf(dst) + if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + if src == nil { + // Ensure the pointer points to a zero version of the value + dv.Elem().Set(reflect.Zero(dv.Type())) + return nil + } + dv = dv.Elem() + // If the pointer is to a nil pointer then set that before scanning + if dv.Kind() == reflect.Ptr && dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + scanner = dv.Interface().(sql.Scanner) + } if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -761,6 +793,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func isScanner(dst interface{}) bool { + if _, ok := dst.(sql.Scanner); ok { + return true + } + if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) { + return true + } + return false +} + // PlanScan prepares a plan to scan a value into dst. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { switch formatCode { @@ -825,13 +869,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } if dt != nil { - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return (*scanPlanDataTypeSQLScanner)(dt) } return (*scanPlanDataTypeAssignTo)(dt) } - if _, ok := dst.(sql.Scanner); ok { + if isScanner(dst) { return scanPlanSQLScanner{} } diff --git a/pgtype_test.go b/pgtype_test.go index 85ca55e9..9127766f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -310,3 +310,44 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } } + +type pgCustomInt int64 + +func (ci *pgCustomInt) Scan(src interface{}) error { + *ci = pgCustomInt(src.(int64)) + return nil +} + +func TestScanPlanBinaryInt32ScanScanner(t *testing.T) { + ci := pgtype.NewConnInfo() + src := []byte{0, 42} + var v pgCustomInt + + plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v) + err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &v) + require.NoError(t, err) + require.EqualValues(t, 42, v) + + ptr := new(pgCustomInt) + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = new(pgCustomInt) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr) + require.NoError(t, err) + require.EqualValues(t, 42, *ptr) + + ptr = nil + plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr) + err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr) + require.NoError(t, err) + assert.Nil(t, ptr) +}