diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 17fe4535..5453bf18 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -8,7 +8,9 @@ import ( type int8Wrapper int8 -func (n *int8Wrapper) ScanInt64(v Int8) error { +func (w int8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int8Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int8") } @@ -19,18 +21,20 @@ func (n *int8Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt8 { return fmt.Errorf("%d is greater than maximum value for int8", v.Int) } - *n = int8Wrapper(v.Int) + *w = int8Wrapper(v.Int) return nil } -func (n int8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int16Wrapper int16 -func (n *int16Wrapper) ScanInt64(v Int8) error { +func (w int16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int16Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int16") } @@ -41,18 +45,20 @@ func (n *int16Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for int16", v.Int) } - *n = int16Wrapper(v.Int) + *w = int16Wrapper(v.Int) return nil } -func (n int16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int32Wrapper int32 -func (n *int32Wrapper) ScanInt64(v Int8) error { +func (w int32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int32") } @@ -63,34 +69,38 @@ func (n *int32Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for int32", v.Int) } - *n = int32Wrapper(v.Int) + *w = int32Wrapper(v.Int) return nil } -func (n int32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type int64Wrapper int64 -func (n *int64Wrapper) ScanInt64(v Int8) error { +func (w int64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *int64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int64") } - *n = int64Wrapper(v.Int) + *w = int64Wrapper(v.Int) return nil } -func (n int64Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type intWrapper int -func (n *intWrapper) ScanInt64(v Int8) error { +func (w intWrapper) SkipUnderlyingTypePlan() {} + +func (w *intWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *int") } @@ -102,18 +112,20 @@ func (n *intWrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is greater than maximum value for int", v.Int) } - *n = intWrapper(v.Int) + *w = intWrapper(v.Int) return nil } -func (n intWrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w intWrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint8Wrapper uint8 -func (n *uint8Wrapper) ScanInt64(v Int8) error { +func (w uint8Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint8Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint8") } @@ -124,18 +136,20 @@ func (n *uint8Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint8 { return fmt.Errorf("%d is greater than maximum value for uint8", v.Int) } - *n = uint8Wrapper(v.Int) + *w = uint8Wrapper(v.Int) return nil } -func (n uint8Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint16Wrapper uint16 -func (n *uint16Wrapper) ScanInt64(v Int8) error { +func (w uint16Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint16Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint16") } @@ -146,18 +160,20 @@ func (n *uint16Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint16 { return fmt.Errorf("%d is greater than maximum value for uint16", v.Int) } - *n = uint16Wrapper(v.Int) + *w = uint16Wrapper(v.Int) return nil } -func (n uint16Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint32Wrapper uint32 -func (n *uint32Wrapper) ScanInt64(v Int8) error { +func (w uint32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint32") } @@ -168,18 +184,20 @@ func (n *uint32Wrapper) ScanInt64(v Int8) error { if v.Int > math.MaxUint32 { return fmt.Errorf("%d is greater than maximum value for uint32", v.Int) } - *n = uint32Wrapper(v.Int) + *w = uint32Wrapper(v.Int) return nil } -func (n uint32Wrapper) Int64Value() (Int8, error) { - return Int8{Int: int64(n), Valid: true}, nil +func (w uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(w), Valid: true}, nil } type uint64Wrapper uint64 -func (n *uint64Wrapper) ScanInt64(v Int8) error { +func (w uint64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *uint64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint64") } @@ -188,22 +206,24 @@ func (n *uint64Wrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is less than minimum value for uint64", v.Int) } - *n = uint64Wrapper(v.Int) + *w = uint64Wrapper(v.Int) return nil } -func (n uint64Wrapper) Int64Value() (Int8, error) { - if uint64(n) > uint64(math.MaxInt64) { - return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) +func (w uint64Wrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type uintWrapper uint -func (n *uintWrapper) ScanInt64(v Int8) error { +func (w uintWrapper) SkipUnderlyingTypePlan() {} + +func (w *uintWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *uint64") } @@ -216,73 +236,79 @@ func (n *uintWrapper) ScanInt64(v Int8) error { return fmt.Errorf("%d is greater than maximum value for uint", v.Int) } - *n = uintWrapper(v.Int) + *w = uintWrapper(v.Int) return nil } -func (n uintWrapper) Int64Value() (Int8, error) { - if uint64(n) > uint64(math.MaxInt64) { - return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) +func (w uintWrapper) Int64Value() (Int8, error) { + if uint64(w) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type float32Wrapper float32 -func (n *float32Wrapper) ScanInt64(v Int8) error { +func (w float32Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float32Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *float32") } - *n = float32Wrapper(v.Int) + *w = float32Wrapper(v.Int) return nil } -func (n float32Wrapper) Int64Value() (Int8, error) { - if n > math.MaxInt64 { - return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) +func (w float32Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type float64Wrapper float64 -func (n *float64Wrapper) ScanInt64(v Int8) error { +func (w float64Wrapper) SkipUnderlyingTypePlan() {} + +func (w *float64Wrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *float64") } - *n = float64Wrapper(v.Int) + *w = float64Wrapper(v.Int) return nil } -func (n float64Wrapper) Int64Value() (Int8, error) { - if n > math.MaxInt64 { - return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) +func (w float64Wrapper) Int64Value() (Int8, error) { + if w > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w) } - return Int8{Int: int64(n), Valid: true}, nil + return Int8{Int: int64(w), Valid: true}, nil } type stringWrapper string -func (s *stringWrapper) ScanInt64(v Int8) error { +func (w stringWrapper) SkipUnderlyingTypePlan() {} + +func (w *stringWrapper) ScanInt64(v Int8) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *string") } - *s = stringWrapper(strconv.FormatInt(v.Int, 10)) + *w = stringWrapper(strconv.FormatInt(v.Int, 10)) return nil } -func (s stringWrapper) Int64Value() (Int8, error) { - num, err := strconv.ParseInt(string(s), 10, 64) +func (w stringWrapper) Int64Value() (Int8, error) { + num, err := strconv.ParseInt(string(w), 10, 64) if err != nil { return Int8{}, err } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bb0d2a9d..c9da1322 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -747,7 +747,12 @@ func tryPointerPointerScanPlan(dst interface{}) (plan *pointerPointerScanPlan, n return nil, nil, false } -var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ +// SkipUnderlyingTypePlanner prevents PlanScan and PlanDecode from trying to use the underlying type. +type SkipUnderlyingTypePlanner interface { + SkipUnderlyingTypePlan() +} + +var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Int: reflect.TypeOf(new(int)), reflect.Int8: reflect.TypeOf(new(int8)), reflect.Int16: reflect.TypeOf(new(int16)), @@ -763,13 +768,13 @@ var elemKindToBasePointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind] reflect.String: reflect.TypeOf(new(string)), } -type baseTypeScanPlan struct { +type underlyingTypeScanPlan struct { dstType reflect.Type nextDstType reflect.Type next ScanPlan } -func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { +func (plan *underlyingTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { if plan.dstType != reflect.TypeOf(dst) { newPlan := ci.PlanScan(oid, formatCode, dst) return newPlan.Scan(ci, oid, formatCode, src, dst) @@ -778,14 +783,18 @@ func (plan *baseTypeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, reflect.ValueOf(dst).Convert(plan.nextDstType).Interface()) } -func tryBaseTypeScanPlan(dst interface{}) (plan *baseTypeScanPlan, nextDst interface{}, ok bool) { +func tryUnderlyingTypeScanPlan(dst interface{}) (plan *underlyingTypeScanPlan, nextDst interface{}, ok bool) { + if _, ok := dst.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { elemValue := dstValue.Elem() - nextDstType := elemKindToBasePointerTypes[elemValue.Kind()] + nextDstType := elemKindToPointerTypes[elemValue.Kind()] if nextDstType != nil && dstValue.Type() != nextDstType { - return &baseTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true + return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } } @@ -881,7 +890,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan } } - if baseTypePlan, nextDst, ok := tryBaseTypeScanPlan(dst); ok { + if baseTypePlan, nextDst, ok := tryUnderlyingTypeScanPlan(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { baseTypePlan.next = nextPlan return baseTypePlan @@ -1004,7 +1013,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco } } - if baseTypePlan, nextValue, ok := tryBaseTypeEncodePlan(value); ok { + if baseTypePlan, nextValue, ok := tryUnderlyingTypeEncodePlan(value); ok { if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { baseTypePlan.next = nextPlan return baseTypePlan @@ -1046,7 +1055,7 @@ func tryDerefPointerEncodePlan(value interface{}) (plan *derefPointerEncodePlan, return nil, nil, false } -var kindToBaseTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ +var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Int: reflect.TypeOf(int(0)), reflect.Int8: reflect.TypeOf(int8(0)), reflect.Int16: reflect.TypeOf(int16(0)), @@ -1062,21 +1071,25 @@ var kindToBaseTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Typ reflect.String: reflect.TypeOf(""), } -type baseTypeEncodePlan struct { +type underlyingTypeEncodePlan struct { nextValueType reflect.Type next EncodePlan } -func (plan *baseTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { +func (plan *underlyingTypeEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { return plan.next.Encode(reflect.ValueOf(value).Convert(plan.nextValueType).Interface(), buf) } -func tryBaseTypeEncodePlan(value interface{}) (plan *baseTypeEncodePlan, nextValue interface{}, ok bool) { +func tryUnderlyingTypeEncodePlan(value interface{}) (plan *underlyingTypeEncodePlan, nextValue interface{}, ok bool) { + if _, ok := value.(SkipUnderlyingTypePlanner); ok { + return nil, nil, false + } + refValue := reflect.ValueOf(value) - nextValueType := kindToBaseTypes[refValue.Kind()] + nextValueType := kindToTypes[refValue.Kind()] if nextValueType != nil && refValue.Type() != nextValueType { - return &baseTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true + return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true } return nil, nil, false