Fix encoding uint64 larger than math.MaxInt64 into numeric

fixes https://github.com/jackc/pgx/issues/1357
pull/1364/head
Jack Christensen 2022-10-29 08:47:12 -05:00
parent c00fb5d2a1
commit 6fabd8f5b1
3 changed files with 57 additions and 2 deletions

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math"
"math/big"
"net"
"net/netip"
"reflect"
@ -223,6 +224,29 @@ func (w uint64Wrapper) Int64Value() (Int8, error) {
return Int8{Int64: int64(w), Valid: true}, nil
}
func (w *uint64Wrapper) ScanNumeric(v Numeric) error {
if !v.Valid {
return fmt.Errorf("cannot scan NULL into *uint64")
}
bi, err := v.toBigInt()
if err != nil {
return fmt.Errorf("cannot scan into *uint64: %v", err)
}
if !bi.IsUint64() {
return fmt.Errorf("cannot scan %v into *uint64", bi.String())
}
*w = uint64Wrapper(v.Int.Uint64())
return nil
}
func (w uint64Wrapper) NumericValue() (Numeric, error) {
return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil
}
type uintWrapper uint
func (w uintWrapper) SkipUnderlyingTypePlan() {}
@ -253,6 +277,35 @@ func (w uintWrapper) Int64Value() (Int8, error) {
return Int8{Int64: int64(w), Valid: true}, nil
}
func (w *uintWrapper) ScanNumeric(v Numeric) error {
if !v.Valid {
return fmt.Errorf("cannot scan NULL into *uint")
}
bi, err := v.toBigInt()
if err != nil {
return fmt.Errorf("cannot scan into *uint: %v", err)
}
if !bi.IsUint64() {
return fmt.Errorf("cannot scan %v into *uint", bi.String())
}
ui := v.Int.Uint64()
if math.MaxUint < ui {
return fmt.Errorf("cannot scan %v into *uint", ui)
}
*w = uintWrapper(ui)
return nil
}
func (w uintWrapper) NumericValue() (Numeric, error) {
return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil
}
type float32Wrapper float32
func (w float32Wrapper) SkipUnderlyingTypePlan() {}

View File

@ -110,6 +110,8 @@ func TestNumericCodec(t *testing.T) {
{int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))},
{int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))},
{int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))},
{uint64(math.MaxUint64), new(uint64), isExpectedEq(uint64(math.MaxUint64))},
{uint(math.MaxUint), new(uint), isExpectedEq(uint(math.MaxUint))},
{"1.23", new(string), isExpectedEq("1.23")},
{pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})},
{nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})},

View File

@ -351,8 +351,8 @@ func NewMap() *Map {
registerDefaultPgTypeVariants[uint8](m, "int8")
registerDefaultPgTypeVariants[uint16](m, "int8")
registerDefaultPgTypeVariants[uint32](m, "int8")
registerDefaultPgTypeVariants[uint64](m, "int8")
registerDefaultPgTypeVariants[uint](m, "int8")
registerDefaultPgTypeVariants[uint64](m, "numeric")
registerDefaultPgTypeVariants[uint](m, "numeric")
registerDefaultPgTypeVariants[float32](m, "float4")
registerDefaultPgTypeVariants[float64](m, "float8")