diff --git a/numeric.go b/numeric.go index 3d209ff2..83791d4e 100644 --- a/numeric.go +++ b/numeric.go @@ -235,7 +235,24 @@ func (dst Numeric) Get() interface{} { return dst } +var NumericDecoderWrapper func(interface{}) NumericDecoder + +type NumericDecoder interface { + DecodeNumeric(*Numeric) error +} + func (src *Numeric) AssignTo(dst interface{}) error { + if d, ok := dst.(NumericDecoder); ok { + return d.DecodeNumeric(src) + } else { + if NumericDecoderWrapper != nil { + d = NumericDecoderWrapper(dst) + if d != nil { + return d.DecodeNumeric(src) + } + } + } + if !src.Valid { return NullAssignTo(dst) } diff --git a/pgtype.go b/pgtype.go index c4fe870d..39e0ad79 100644 --- a/pgtype.go +++ b/pgtype.go @@ -225,15 +225,18 @@ type ConnInfo struct { oidToResultFormatCode map[uint32]int16 reflectTypeToDataType map[reflect.Type]*DataType + + preferAssignToOverSQLScannerTypes map[reflect.Type]struct{} } func newConnInfo() *ConnInfo { return &ConnInfo{ - oidToDataType: make(map[uint32]*DataType), - nameToDataType: make(map[string]*DataType), - reflectTypeToName: make(map[reflect.Type]string), - oidToParamFormatCode: make(map[uint32]int16), - oidToResultFormatCode: make(map[uint32]int16), + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToParamFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), } } @@ -462,6 +465,12 @@ func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { return TextFormatCode } +// PreferAssignToOverSQLScannerForType makes a sql.Scanner type use the AssignTo scan path instead of sql.Scanner. +// This is primarily for efficient integration with 3rd party numeric and UUID types. +func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { + ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} +} + // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := newConnInfo() @@ -478,6 +487,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2.reflectTypeToName[t] = n } + for t, _ := range ci.preferAssignToOverSQLScannerTypes { + ci2.preferAssignToOverSQLScannerTypes[t] = struct{}{} + } + return ci2 } @@ -808,7 +821,9 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if dt != nil { if _, ok := dst.(sql.Scanner); ok { - return (*scanPlanDataTypeSQLScanner)(dt) + if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { + return (*scanPlanDataTypeSQLScanner)(dt) + } } return (*scanPlanDataTypeAssignTo)(dt) }