mirror of https://github.com/jackc/pgx.git
Add Numeric.MarshalJSON
parent
55195b3a64
commit
0d9bd0366b
39
numeric.go
39
numeric.go
|
@ -1,6 +1,7 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
@ -807,3 +808,41 @@ func (src Numeric) Value() (driver.Value, error) {
|
|||
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func (src Numeric) MarshalJSON() ([]byte, error) {
|
||||
if !src.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
if src.NaN {
|
||||
return []byte(`"NaN"`), nil
|
||||
}
|
||||
|
||||
intStr := src.Int.String()
|
||||
buf := &bytes.Buffer{}
|
||||
exp := int(src.Exp)
|
||||
if exp > 0 {
|
||||
buf.WriteString(intStr)
|
||||
for i := 0; i < exp; i++ {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
} else if exp < 0 {
|
||||
if len(intStr) <= -exp {
|
||||
buf.WriteString("0.")
|
||||
leadingZeros := -exp - len(intStr)
|
||||
for i := 0; i < leadingZeros; i++ {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
buf.WriteString(intStr)
|
||||
} else if len(intStr) > -exp {
|
||||
dpPos := len(intStr) + exp
|
||||
buf.WriteString(intStr[:dpPos])
|
||||
buf.WriteByte('.')
|
||||
buf.WriteString(intStr[dpPos:])
|
||||
}
|
||||
} else {
|
||||
buf.WriteString(intStr)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
|
@ -9,6 +11,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgtype/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0)
|
||||
|
@ -410,3 +413,35 @@ func TestNumericEncodeDecodeBinary(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumericMarshalJSON(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
for i, tt := range []struct {
|
||||
decString string
|
||||
}{
|
||||
{"NaN"},
|
||||
{"0"},
|
||||
{"1"},
|
||||
{"-1"},
|
||||
{"1000000000000000000"},
|
||||
{"1234.56789"},
|
||||
{"1.56789"},
|
||||
{"0.00000000000056789"},
|
||||
{"0.00123000"},
|
||||
{"123e-3"},
|
||||
{"243723409723490243842378942378901237502734019231380123e23790"},
|
||||
{"3409823409243892349028349023482934092340892390101e-14021"},
|
||||
} {
|
||||
var num pgtype.Numeric
|
||||
var pgJSON string
|
||||
err := conn.QueryRow(context.Background(), `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON)
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
|
||||
goJSON, err := json.Marshal(num)
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
|
||||
require.Equal(t, pgJSON, string(goJSON))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue