diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 7a992b09..d93f8d59 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "math/big" "net" "net/netip" "reflect" @@ -223,6 +224,29 @@ func (w uint64Wrapper) Int64Value() (Int8, error) { return Int8{Int64: int64(w), Valid: true}, nil } +func (w *uint64Wrapper) ScanNumeric(v Numeric) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + bi, err := v.toBigInt() + if err != nil { + return fmt.Errorf("cannot scan into *uint64: %v", err) + } + + if !bi.IsUint64() { + return fmt.Errorf("cannot scan %v into *uint64", bi.String()) + } + + *w = uint64Wrapper(v.Int.Uint64()) + + return nil +} + +func (w uint64Wrapper) NumericValue() (Numeric, error) { + return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil +} + type uintWrapper uint func (w uintWrapper) SkipUnderlyingTypePlan() {} @@ -253,6 +277,35 @@ func (w uintWrapper) Int64Value() (Int8, error) { return Int8{Int64: int64(w), Valid: true}, nil } +func (w *uintWrapper) ScanNumeric(v Numeric) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint") + } + + bi, err := v.toBigInt() + if err != nil { + return fmt.Errorf("cannot scan into *uint: %v", err) + } + + if !bi.IsUint64() { + return fmt.Errorf("cannot scan %v into *uint", bi.String()) + } + + ui := v.Int.Uint64() + + if math.MaxUint < ui { + return fmt.Errorf("cannot scan %v into *uint", ui) + } + + *w = uintWrapper(ui) + + return nil +} + +func (w uintWrapper) NumericValue() (Numeric, error) { + return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil +} + type float32Wrapper float32 func (w float32Wrapper) SkipUnderlyingTypePlan() {} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 071f0c24..46faa5ce 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -110,6 +110,8 @@ func TestNumericCodec(t *testing.T) { {int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))}, {int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))}, {int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))}, + {uint64(math.MaxUint64), new(uint64), isExpectedEq(uint64(math.MaxUint64))}, + {uint(math.MaxUint), new(uint), isExpectedEq(uint(math.MaxUint))}, {"1.23", new(string), isExpectedEq("1.23")}, {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index a234b855..2aa96e68 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -351,8 +351,8 @@ func NewMap() *Map { registerDefaultPgTypeVariants[uint8](m, "int8") registerDefaultPgTypeVariants[uint16](m, "int8") registerDefaultPgTypeVariants[uint32](m, "int8") - registerDefaultPgTypeVariants[uint64](m, "int8") - registerDefaultPgTypeVariants[uint](m, "int8") + registerDefaultPgTypeVariants[uint64](m, "numeric") + registerDefaultPgTypeVariants[uint](m, "numeric") registerDefaultPgTypeVariants[float32](m, "float4") registerDefaultPgTypeVariants[float64](m, "float8")