diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index ef3ce201..c75efa36 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -263,6 +263,16 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } + if num.NaN { + return errors.New("cannot decode 'NaN'") + } + if num.InfinityModifier == pgtype.Infinity { + return errors.New("cannot decode 'Infinity'") + } + if num.InfinityModifier == pgtype.NegativeInfinity { + return errors.New("cannot decode '-Infinity'") + } + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index e635da41..e3c6d59d 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -1,6 +1,7 @@ package numeric_test import ( + "context" "fmt" "math/big" "math/rand" @@ -93,6 +94,15 @@ func TestNumericNormalize(t *testing.T) { }) } +func TestNumericNaN(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var n shopspring.Numeric + err := conn.QueryRow(context.Background(), `select 'NaN'::numeric`).Scan(&n) + require.EqualError(t, err, `can't scan into dest[0]: cannot decode 'NaN'`) +} + func TestNumericTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present},