mirror of https://github.com/jackc/pgx.git
add UnmarshalJSON for pgtype Numeric
parent
384a581e99
commit
766d2bba4f
|
@ -240,6 +240,18 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
|
||||||
return n.numberTextBytes(), nil
|
return n.numberTextBytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Numeric) UnmarshalJSON(src []byte) error {
|
||||||
|
if bytes.Compare(src, []byte(`null`)) == 0 {
|
||||||
|
*n = Numeric{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if bytes.Compare(src, []byte(`"NaN"`)) == 0 {
|
||||||
|
*n = Numeric{NaN: true, Valid: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return scanPlanTextAnyToNumericScanner{}.Scan(src, n)
|
||||||
|
}
|
||||||
|
|
||||||
// numberString returns a string of the number. undefined if NaN, infinite, or NULL
|
// numberString returns a string of the number. undefined if NaN, infinite, or NULL
|
||||||
func (n Numeric) numberTextBytes() []byte {
|
func (n Numeric) numberTextBytes() []byte {
|
||||||
intStr := n.Int.String()
|
intStr := n.Int.String()
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"math/big"
|
"math/big"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -232,3 +233,72 @@ func TestNumericMarshalJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNumericUnmarshalJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *pgtype.Numeric
|
||||||
|
src []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "null",
|
||||||
|
want: &pgtype.Numeric{},
|
||||||
|
src: []byte(`null`),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NaN",
|
||||||
|
want: &pgtype.Numeric{Valid: true, NaN: true},
|
||||||
|
src: []byte(`"NaN"`),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "0",
|
||||||
|
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(0)},
|
||||||
|
src: []byte("0"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1",
|
||||||
|
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1)},
|
||||||
|
src: []byte("1"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "-1",
|
||||||
|
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)},
|
||||||
|
src: []byte("-1"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bigInt",
|
||||||
|
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30},
|
||||||
|
src: []byte("1000000000000000000000000000000"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "float: 1234.56789",
|
||||||
|
want: &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5},
|
||||||
|
src: []byte("1234.56789"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid value",
|
||||||
|
want: &pgtype.Numeric{},
|
||||||
|
src: []byte("0xffff"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &pgtype.Numeric{}
|
||||||
|
if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue