diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 9d51a7e5..376c03fe 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -240,6 +240,18 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +func (n *Numeric) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte(`null`)) == 0 { + *n = Numeric{} + return nil + } + if bytes.Compare(src, []byte(`"NaN"`)) == 0 { + *n = Numeric{NaN: true, Valid: true} + return nil + } + return scanPlanTextAnyToNumericScanner{}.Scan(src, n) +} + // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { intStr := n.Int.String() diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index f0591246..691cc979 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -6,6 +6,7 @@ import ( "math" "math/big" "math/rand" + "reflect" "strconv" "testing" @@ -232,3 +233,72 @@ func TestNumericMarshalJSON(t *testing.T) { } }) } + +func TestNumericUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.Numeric + src []byte + wantErr bool + }{ + { + name: "null", + want: &pgtype.Numeric{}, + src: []byte(`null`), + wantErr: false, + }, + { + name: "NaN", + want: &pgtype.Numeric{Valid: true, NaN: true}, + src: []byte(`"NaN"`), + wantErr: false, + }, + { + name: "0", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(0)}, + src: []byte("0"), + wantErr: false, + }, + { + name: "1", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1)}, + src: []byte("1"), + wantErr: false, + }, + { + name: "-1", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)}, + src: []byte("-1"), + wantErr: false, + }, + { + name: "bigInt", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30}, + src: []byte("1000000000000000000000000000000"), + wantErr: false, + }, + { + name: "float: 1234.56789", + want: &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5}, + src: []byte("1234.56789"), + wantErr: false, + }, + { + name: "invalid value", + want: &pgtype.Numeric{}, + src: []byte("0xffff"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.Numeric{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +}