mirror of https://github.com/jackc/pgx.git
add UnmarshalJSON for pgtype Numeric
parent
384a581e99
commit
766d2bba4f
|
@ -240,6 +240,18 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
|
|||
return n.numberTextBytes(), nil
|
||||
}
|
||||
|
||||
func (n *Numeric) UnmarshalJSON(src []byte) error {
|
||||
if bytes.Compare(src, []byte(`null`)) == 0 {
|
||||
*n = Numeric{}
|
||||
return nil
|
||||
}
|
||||
if bytes.Compare(src, []byte(`"NaN"`)) == 0 {
|
||||
*n = Numeric{NaN: true, Valid: true}
|
||||
return nil
|
||||
}
|
||||
return scanPlanTextAnyToNumericScanner{}.Scan(src, n)
|
||||
}
|
||||
|
||||
// numberString returns a string of the number. undefined if NaN, infinite, or NULL
|
||||
func (n Numeric) numberTextBytes() []byte {
|
||||
intStr := n.Int.String()
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"math"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
|
@ -232,3 +233,72 @@ func TestNumericMarshalJSON(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNumericUnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want *pgtype.Numeric
|
||||
src []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "null",
|
||||
want: &pgtype.Numeric{},
|
||||
src: []byte(`null`),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "NaN",
|
||||
want: &pgtype.Numeric{Valid: true, NaN: true},
|
||||
src: []byte(`"NaN"`),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "0",
|
||||
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(0)},
|
||||
src: []byte("0"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1",
|
||||
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1)},
|
||||
src: []byte("1"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "-1",
|
||||
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)},
|
||||
src: []byte("-1"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bigInt",
|
||||
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30},
|
||||
src: []byte("1000000000000000000000000000000"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "float: 1234.56789",
|
||||
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5},
|
||||
src: []byte("1234.56789"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
want: &pgtype.Numeric{},
|
||||
src: []byte("0xffff"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &pgtype.Numeric{}
|
||||
if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr {
|
||||
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue