mirror of https://github.com/jackc/pgx.git
parent
d846dbcb75
commit
824d8ad40d
54
pgtype.go
54
pgtype.go
|
@ -533,8 +533,22 @@ type scanPlanDataTypeSQLScanner DataType
|
|||
func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||
scanner, ok := dst.(sql.Scanner)
|
||||
if !ok {
|
||||
newPlan := ci.PlanScan(oid, formatCode, dst)
|
||||
return newPlan.Scan(ci, oid, formatCode, src, dst)
|
||||
dv := reflect.ValueOf(dst)
|
||||
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
|
||||
newPlan := ci.PlanScan(oid, formatCode, dst)
|
||||
return newPlan.Scan(ci, oid, formatCode, src, dst)
|
||||
}
|
||||
if src == nil {
|
||||
// Ensure the pointer points to a zero version of the value
|
||||
dv.Elem().Set(reflect.Zero(dv.Type().Elem()))
|
||||
return nil
|
||||
}
|
||||
dv = dv.Elem()
|
||||
// If the pointer is to a nil pointer then set that before scanning
|
||||
if dv.Kind() == reflect.Ptr && dv.IsNil() {
|
||||
dv.Set(reflect.New(dv.Type().Elem()))
|
||||
}
|
||||
scanner = dv.Interface().(sql.Scanner)
|
||||
}
|
||||
|
||||
dt := (*DataType)(plan)
|
||||
|
@ -593,7 +607,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode
|
|||
type scanPlanSQLScanner struct{}
|
||||
|
||||
func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
|
||||
scanner := dst.(sql.Scanner)
|
||||
scanner, ok := dst.(sql.Scanner)
|
||||
if !ok {
|
||||
dv := reflect.ValueOf(dst)
|
||||
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
|
||||
newPlan := ci.PlanScan(oid, formatCode, dst)
|
||||
return newPlan.Scan(ci, oid, formatCode, src, dst)
|
||||
}
|
||||
if src == nil {
|
||||
// Ensure the pointer points to a zero version of the value
|
||||
dv.Elem().Set(reflect.Zero(dv.Type()))
|
||||
return nil
|
||||
}
|
||||
dv = dv.Elem()
|
||||
// If the pointer is to a nil pointer then set that before scanning
|
||||
if dv.Kind() == reflect.Ptr && dv.IsNil() {
|
||||
dv.Set(reflect.New(dv.Type().Elem()))
|
||||
}
|
||||
scanner = dv.Interface().(sql.Scanner)
|
||||
}
|
||||
if src == nil {
|
||||
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
||||
// text format path would be converted to empty string.
|
||||
|
@ -761,6 +793,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt
|
|||
return newPlan.Scan(ci, oid, formatCode, src, dst)
|
||||
}
|
||||
|
||||
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
|
||||
func isScanner(dst interface{}) bool {
|
||||
if _, ok := dst.(sql.Scanner); ok {
|
||||
return true
|
||||
}
|
||||
if t := reflect.TypeOf(dst); t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PlanScan prepares a plan to scan a value into dst.
|
||||
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
|
||||
switch formatCode {
|
||||
|
@ -825,13 +869,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan
|
|||
}
|
||||
|
||||
if dt != nil {
|
||||
if _, ok := dst.(sql.Scanner); ok {
|
||||
if isScanner(dst) {
|
||||
return (*scanPlanDataTypeSQLScanner)(dt)
|
||||
}
|
||||
return (*scanPlanDataTypeAssignTo)(dt)
|
||||
}
|
||||
|
||||
if _, ok := dst.(sql.Scanner); ok {
|
||||
if isScanner(dst) {
|
||||
return scanPlanSQLScanner{}
|
||||
}
|
||||
|
||||
|
|
|
@ -310,3 +310,44 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
type pgCustomInt int64
|
||||
|
||||
func (ci *pgCustomInt) Scan(src interface{}) error {
|
||||
*ci = pgCustomInt(src.(int64))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestScanPlanBinaryInt32ScanScanner(t *testing.T) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
src := []byte{0, 42}
|
||||
var v pgCustomInt
|
||||
|
||||
plan := ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &v)
|
||||
err := plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &v)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 42, v)
|
||||
|
||||
ptr := new(pgCustomInt)
|
||||
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
|
||||
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 42, *ptr)
|
||||
|
||||
ptr = new(pgCustomInt)
|
||||
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ptr)
|
||||
|
||||
ptr = nil
|
||||
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
|
||||
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, src, &ptr)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 42, *ptr)
|
||||
|
||||
ptr = nil
|
||||
plan = ci.PlanScan(pgtype.Int2OID, pgtype.BinaryFormatCode, &ptr)
|
||||
err = plan.Scan(ci, pgtype.Int2OID, pgtype.BinaryFormatCode, nil, &ptr)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ptr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue