diff --git a/numeric.go b/numeric.go index f5260548..3f2dc9ae 100644 --- a/numeric.go +++ b/numeric.go @@ -49,10 +49,11 @@ var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) type Numeric struct { - Int *big.Int - Exp int32 - Status Status - NaN bool + Int *big.Int + Exp int32 + Status Status + NaN bool + InfinityModifier InfinityModifier } func (dst *Numeric) Set(src interface{}) error { @@ -73,6 +74,12 @@ func (dst *Numeric) Set(src interface{}) error { if math.IsNaN(float64(value)) { *dst = Numeric{Status: Present, NaN: true} return nil + } else if math.IsInf(float64(value), 1) { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if math.IsInf(float64(value), -1) { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) if err != nil { @@ -83,6 +90,12 @@ func (dst *Numeric) Set(src interface{}) error { if math.IsNaN(value) { *dst = Numeric{Status: Present, NaN: true} return nil + } else if math.IsInf(value, 1) { + *dst = Numeric{Status: Present, InfinityModifier: Infinity} + return nil + } else if math.IsInf(value, -1) { + *dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity} + return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) if err != nil { @@ -193,6 +206,8 @@ func (dst *Numeric) Set(src interface{}) error { } else { return dst.Set(*value) } + case InfinityModifier: + *dst = Numeric{InfinityModifier: value, Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) @@ -206,6 +221,9 @@ func (dst *Numeric) Set(src interface{}) error { func (dst Numeric) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst case Null: return nil diff --git a/numeric_test.go b/numeric_test.go index fff5a2e0..f14cf960 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -222,6 +222,12 @@ func TestNumericSet(t *testing.T) { {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, NaN: true}}, + {source: pgtype.Infinity, result: &pgtype.Numeric{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: math.Inf(1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, + {source: pgtype.NegativeInfinity, result: &pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + {source: math.Inf(-1), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.NegativeInfinity}}, + {source: float32(math.Inf(1)), result: &pgtype.Numeric{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}}, } for i, tt := range successfulTests {