From a74ebc9e51fe210504c63a13220390bbc8b1cef6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Feb 2022 08:39:53 -0600 Subject: [PATCH] pgtype.Numeric implements Float64Valuer --- pgtype/numeric.go | 55 ++++++++++++++++++++++++------------------ pgtype/numeric_test.go | 25 +++++++++++++++++++ 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 5bdbd4d5..d2311f3a 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -80,6 +80,35 @@ func (n Numeric) NumericValue() (Numeric, error) { return n, nil } +func (n Numeric) Float64Value() (Float8, error) { + if !n.Valid { + return Float8{}, nil + } else if n.NaN { + return Float8{Float: math.NaN(), Valid: true}, nil + } else if n.InfinityModifier == Infinity { + return Float8{Float: math.Inf(1), Valid: true}, nil + } else if n.InfinityModifier == NegativeInfinity { + return Float8{Float: math.Inf(-1), Valid: true}, nil + } + + buf := make([]byte, 0, 32) + + if n.Int == nil { + buf = append(buf, '0') + } else { + buf = append(buf, n.Int.String()...) + } + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return Float8{}, err + } + + return Float8{Float: f, Valid: true}, nil +} + func (n *Numeric) toBigInt() (*big.Int, error) { if n.Exp == 0 { return n.Int, nil @@ -104,28 +133,6 @@ func (n *Numeric) toBigInt() (*big.Int, error) { return num, nil } -func (n *Numeric) toFloat64() (float64, error) { - if n.NaN { - return math.NaN(), nil - } else if n.InfinityModifier == Infinity { - return math.Inf(1), nil - } else if n.InfinityModifier == NegativeInfinity { - return math.Inf(-1), nil - } - - buf := make([]byte, 0, 32) - - buf = append(buf, n.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) - - f, err := strconv.ParseFloat(string(buf), 64) - if err != nil { - return 0, err - } - return f, nil -} - func parseNumericString(str string) (n *big.Int, exp int32, err error) { parts := strings.SplitN(str, ".", 2) digits := strings.Join(parts, "") @@ -642,12 +649,12 @@ func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst interface{}) e return err } - f64, err := n.toFloat64() + f8, err := n.Float64Value() if err != nil { return err } - return scanner.ScanFloat64(Float8{Float: f64, Valid: true}) + return scanner.ScanFloat64(f8) } type scanPlanBinaryNumericToInt64Scanner struct{} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index c74fb9a3..0449059e 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -117,6 +118,30 @@ func TestNumericCodec(t *testing.T) { }) } +func TestNumericFloat64Valuer(t *testing.T) { + for i, tt := range []struct { + n pgtype.Numeric + f pgtype.Float8 + }{ + {mustParseNumeric(t, "1"), pgtype.Float8{Float: 1, Valid: true}}, + {mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float: 0.0000000000000000001, Valid: true}}, + {mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float: -99999999999, Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float: math.Inf(1), Valid: true}}, + {pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float: math.Inf(-1), Valid: true}}, + {pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}}, + {pgtype.Numeric{}, pgtype.Float8{}}, + } { + f, err := tt.n.Float64Value() + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.f, f, "%d", i) + } + + f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value() + assert.NoError(t, err) + assert.True(t, math.IsNaN(f.Float)) + assert.True(t, f.Valid) +} + func TestNumericCodecFuzz(t *testing.T) { r := rand.New(rand.NewSource(0)) max := &big.Int{}