diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5d0ed882..c0d02197 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -730,11 +730,15 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt return newPlan.Scan(ci, oid, formatCode, src, dst) } +type tryWrapScanPlanFunc func(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) + type pointerPointerScanPlan struct { dstType reflect.Type next ScanPlan } +func (plan *pointerPointerScanPlan) SetNext(next ScanPlan) { plan.next = next } + func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) @@ -751,7 +755,7 @@ func (plan *pointerPointerScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, el.Interface()) } -func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, nextDst interface{}, ok bool) { +func tryPointerPointerScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { if dstValue := reflect.ValueOf(dst); dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() if elemValue.Kind() == reflect.Ptr { @@ -790,6 +794,8 @@ type underlyingTypeScanPlan struct { next ScanPlan } +func (plan *underlyingTypeScanPlan) SetNext(next ScanPlan) { plan.next = next } + func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) @@ -799,7 +805,7 @@ func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode in return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } -func tryUnderlyingTypeScanPlan(dst interface{}) (plan *underlyingTypeScanPlan, nextDst interface{}, ok bool) { +func tryUnderlyingTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter, nextDst interface{}, ok bool) { if _, ok := dst.(SkipUnderlyingTypePlanner); ok { return nil, nil, false } @@ -1128,24 +1134,18 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan return plan } - if pointerPointerPlan, nextDst, ok := tryPointerPointerScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - pointerPointerPlan.next = nextPlan - return pointerPointerPlan - } + tryWrappers := []tryWrapScanPlanFunc{ + tryPointerPointerScanPlan, + tryUnderlyingTypeScanPlan, + tryWrapBuiltinTypeScanPlan, } - if baseTypePlan, nextDst, ok := tryUnderlyingTypeScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - baseTypePlan.next = nextPlan - return baseTypePlan - } - } - - if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeScanPlan(dst); ok { - if nextPlan := ci.PlanScan(oid, formatCode, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan + for _, f := range tryWrappers { + if wrapperPlan, nextDst, ok := f(dst); ok { + if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } @@ -1259,36 +1259,33 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco return plan } - if derefPointerPlan, nextValue, ok := tryDerefPointerEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - derefPointerPlan.next = nextPlan - return derefPointerPlan - } + tryWrappers := []tryWrapEncodePlanFunc{ + tryDerefPointerEncodePlan, + tryUnderlyingTypeEncodePlan, + tryWrapBuiltinTypeEncodePlan, } - if baseTypePlan, nextValue, ok := tryUnderlyingTypeEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - baseTypePlan.next = nextPlan - return baseTypePlan + for _, f := range tryWrappers { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } - - if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeEncodePlan(value); ok { - if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } return nil } +type tryWrapEncodePlanFunc func(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) + type derefPointerEncodePlan struct { next EncodePlan } +func (plan *derefPointerEncodePlan) SetNext(next EncodePlan) { plan.next = next } + func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { ptr := reflect.ValueOf(value) @@ -1299,7 +1296,7 @@ func (plan *derefPointerEncodePlan) Encode(value interface{}, buf []byte) (newBu return plan.next.Encode(ptr.Elem().Interface(), buf) } -func tryDerefPointerEncodePlan(value interface{}) (plan *derefPointerEncodePlan, nextValue interface{}, ok bool) { +func tryDerefPointerEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if valueType := reflect.TypeOf(value); valueType.Kind() == reflect.Ptr { return &derefPointerEncodePlan{}, reflect.New(valueType.Elem()).Elem().Interface(), true } @@ -1328,11 +1325,13 @@ type underlyingTypeEncodePlan struct { next EncodePlan } +func (plan *underlyingTypeEncodePlan) SetNext(next EncodePlan) { plan.next = next } + func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } -func tryUnderlyingTypeEncodePlan(value interface{}) (plan *underlyingTypeEncodePlan, nextValue interface{}, ok bool) { +func tryUnderlyingTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { if _, ok := value.(SkipUnderlyingTypePlanner); ok { return nil, nil, false }