add UnmarshalJSON for pgtype Numeric

pull/1490/head
Yumin Xia 2023-01-29 21:14:49 -08:00 committed by Jack Christensen
parent 384a581e99
commit 766d2bba4f
2 changed files with 82 additions and 0 deletions

View File

@ -240,6 +240,18 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
return n.numberTextBytes(), nil
}
func (n *Numeric) UnmarshalJSON(src []byte) error {
if bytes.Compare(src, []byte(`null`)) == 0 {
*n = Numeric{}
return nil
}
if bytes.Compare(src, []byte(`"NaN"`)) == 0 {
*n = Numeric{NaN: true, Valid: true}
return nil
}
return scanPlanTextAnyToNumericScanner{}.Scan(src, n)
}
// numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte {
intStr := n.Int.String()

View File

@ -6,6 +6,7 @@ import (
"math"
"math/big"
"math/rand"
"reflect"
"strconv"
"testing"
@ -232,3 +233,72 @@ func TestNumericMarshalJSON(t *testing.T) {
}
})
}
func TestNumericUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
want *pgtype.Numeric
src []byte
wantErr bool
}{
{
name: "null",
want: &pgtype.Numeric{},
src: []byte(`null`),
wantErr: false,
},
{
name: "NaN",
want: &pgtype.Numeric{Valid: true, NaN: true},
src: []byte(`"NaN"`),
wantErr: false,
},
{
name: "0",
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(0)},
src: []byte("0"),
wantErr: false,
},
{
name: "1",
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1)},
src: []byte("1"),
wantErr: false,
},
{
name: "-1",
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)},
src: []byte("-1"),
wantErr: false,
},
{
name: "bigInt",
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30},
src: []byte("1000000000000000000000000000000"),
wantErr: false,
},
{
name: "float: 1234.56789",
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5},
src: []byte("1234.56789"),
wantErr: false,
},
{
name: "invalid value",
want: &pgtype.Numeric{},
src: []byte("0xffff"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &pgtype.Numeric{}
if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}