Add pgtype.Numeric

box
Jack Christensen 2017-04-01 23:33:04 -05:00
parent c5d247830c
commit 5ad2c4e2b9
5 changed files with 964 additions and 0 deletions

35
pgtype/decimal.go Normal file
View File

@ -0,0 +1,35 @@
package pgtype
import (
"io"
)
type Decimal Numeric
func (dst *Decimal) Set(src interface{}) error {
return (*Numeric)(dst).Set(src)
}
func (dst *Decimal) Get() interface{} {
return (*Numeric)(dst).Get()
}
func (src *Decimal) AssignTo(dst interface{}) error {
return (*Numeric)(src).AssignTo(dst)
}
func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeText(ci, src)
}
func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeBinary(ci, src)
}
func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
return (*Numeric)(src).EncodeText(ci, w)
}
func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
return (*Numeric)(src).EncodeBinary(ci, w)
}

602
pgtype/numeric.go Normal file
View File

@ -0,0 +1,602 @@
package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"math"
"math/big"
"strconv"
"strings"
"github.com/jackc/pgx/pgio"
)
// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
const nbase = 10000
var big0 *big.Int = big.NewInt(0)
var big10 *big.Int = big.NewInt(10)
var big100 *big.Int = big.NewInt(100)
var big1000 *big.Int = big.NewInt(1000)
var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
var bigMinInt *big.Int = big.NewInt(int64(minInt))
var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
var bigNBase *big.Int = big.NewInt(nbase)
var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
type Numeric struct {
Int *big.Int
Exp int32
Status Status
}
func (dst *Numeric) Set(src interface{}) error {
if src == nil {
*dst = Numeric{Status: Null}
return nil
}
switch value := src.(type) {
case float32:
num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
if err != nil {
return err
}
*dst = Numeric{Int: num, Exp: exp, Status: Present}
case float64:
num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
if err != nil {
return err
}
*dst = Numeric{Int: num, Exp: exp, Status: Present}
case int8:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case uint8:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case int16:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case uint16:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case int32:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case uint32:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case int64:
*dst = Numeric{Int: big.NewInt(value), Status: Present}
case uint64:
*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
case int:
*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
case uint:
*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
case string:
num, exp, err := parseNumericString(value)
if err != nil {
return err
}
*dst = Numeric{Int: num, Exp: exp, Status: Present}
default:
if originalSrc, ok := underlyingNumberType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to Numeric", value)
}
return nil
}
func (dst *Numeric) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *Numeric) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {
case *float32:
f, err := strconv.ParseFloat(src.Int.String(), 64)
if err != nil {
return err
}
return float64AssignTo(f, src.Status, dst)
case *float64:
f, err := strconv.ParseFloat(src.Int.String(), 64)
if err != nil {
return err
}
return float64AssignTo(f, src.Status, dst)
case *int:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(bigMaxInt) > 0 {
return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
}
if normalizedInt.Cmp(bigMinInt) < 0 {
return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
}
*v = int(normalizedInt.Int64())
case *int8:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(bigMaxInt8) > 0 {
return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
}
if normalizedInt.Cmp(bigMinInt8) < 0 {
return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
}
*v = int8(normalizedInt.Int64())
case *int16:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(bigMaxInt16) > 0 {
return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
}
if normalizedInt.Cmp(bigMinInt16) < 0 {
return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
}
*v = int16(normalizedInt.Int64())
case *int32:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(bigMaxInt32) > 0 {
return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
}
if normalizedInt.Cmp(bigMinInt32) < 0 {
return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
}
*v = int32(normalizedInt.Int64())
case *int64:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(bigMaxInt64) > 0 {
return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
}
if normalizedInt.Cmp(bigMinInt64) < 0 {
return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
}
*v = normalizedInt.Int64()
case *uint:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(big0) < 0 {
return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
} else if normalizedInt.Cmp(bigMaxUint) > 0 {
return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
}
*v = uint(normalizedInt.Uint64())
case *uint8:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(big0) < 0 {
return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
}
*v = uint8(normalizedInt.Uint64())
case *uint16:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(big0) < 0 {
return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
}
*v = uint16(normalizedInt.Uint64())
case *uint32:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(big0) < 0 {
return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
}
*v = uint32(normalizedInt.Uint64())
case *uint64:
normalizedInt, err := src.toBigInt()
if err != nil {
return err
}
if normalizedInt.Cmp(big0) < 0 {
return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
}
*v = normalizedInt.Uint64()
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return nullAssignTo(dst)
}
return nil
}
func (dst *Numeric) toBigInt() (*big.Int, error) {
if dst.Exp == 0 {
return dst.Int, nil
}
num := &big.Int{}
num.Set(dst.Int)
if dst.Exp > 0 {
mul := &big.Int{}
mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil)
num.Mul(num, mul)
return num, nil
}
div := &big.Int{}
div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
remainder := &big.Int{}
num.DivMod(num, div, remainder)
if remainder.Cmp(big0) != 0 {
return nil, fmt.Errorf("cannot convert %v to integer", dst)
}
return num, nil
}
func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Numeric{Status: Null}
return nil
}
num, exp, err := parseNumericString(string(src))
if err != nil {
return err
}
*dst = Numeric{Int: num, Exp: exp, Status: Present}
return nil
}
func parseNumericString(str string) (n *big.Int, exp int32, err error) {
parts := strings.SplitN(str, ".", 2)
digits := strings.Join(parts, "")
if len(parts) > 1 {
exp = int32(-len(parts[1]))
} else {
for len(digits) > 1 && digits[len(digits)-1] == '0' {
digits = digits[:len(digits)-1]
exp++
}
}
accum := &big.Int{}
if _, ok := accum.SetString(digits, 10); !ok {
return nil, 0, fmt.Errorf("%s is not a number", str)
}
return accum, exp, nil
}
func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Numeric{Status: Null}
return nil
}
if len(src) < 8 {
return fmt.Errorf("numeric incomplete %v", src)
}
rp := 0
ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if ndigits == 0 {
*dst = Numeric{Int: big.NewInt(0), Status: Present}
return nil
}
weight := int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
sign := int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
dscale := int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if len(src[rp:]) < int(ndigits)*2 {
return fmt.Errorf("numeric incomplete %v", src)
}
accum := &big.Int{}
for i := 0; i < int(ndigits+3)/4; i++ {
int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
rp += bytesRead
if i > 0 {
var mul *big.Int
switch digitsRead {
case 1:
mul = bigNBase
case 2:
mul = bigNBaseX2
case 3:
mul = bigNBaseX3
case 4:
mul = bigNBaseX4
default:
return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
}
accum.Mul(accum, mul)
}
accum.Add(accum, big.NewInt(int64accum))
}
exp := (int32(weight) - int32(ndigits) + 1) * 4
if dscale > 0 {
fracNBaseDigits := ndigits - weight - 1
fracDecimalDigits := fracNBaseDigits * 4
if dscale > fracDecimalDigits {
multCount := int(dscale - fracDecimalDigits)
for i := 0; i < multCount; i++ {
accum.Mul(accum, big10)
exp--
}
} else if dscale < fracDecimalDigits {
divCount := int(fracDecimalDigits - dscale)
for i := 0; i < divCount; i++ {
accum.Div(accum, big10)
exp++
}
}
}
reduced := &big.Int{}
remainder := &big.Int{}
if exp >= 0 {
for {
reduced.DivMod(accum, big10, remainder)
if remainder.Cmp(big0) != 0 {
break
}
accum.Set(reduced)
exp++
}
}
if sign != 0 {
accum.Neg(accum)
}
*dst = Numeric{Int: accum, Exp: exp, Status: Present}
return nil
}
func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
digits := len(src) / 2
if digits > 4 {
digits = 4
}
rp := 0
for i := 0; i < digits; i++ {
if i > 0 {
accum *= nbase
}
accum += int64(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
return accum, rp, digits
}
func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) {
switch src.Status {
case Null:
return true, nil
case Undefined:
return false, errUndefined
}
if _, err := io.WriteString(w, src.Int.String()); err != nil {
return false, err
}
if err := pgio.WriteByte(w, 'e'); err != nil {
return false, err
}
if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil {
return false, err
}
return false, nil
}
func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) {
switch src.Status {
case Null:
return true, nil
case Undefined:
return false, errUndefined
}
var sign int16
if src.Int.Cmp(big0) < 0 {
sign = 16384
}
absInt := &big.Int{}
wholePart := &big.Int{}
fracPart := &big.Int{}
remainder := &big.Int{}
absInt.Abs(src.Int)
// Normalize absInt and exp to where exp is always a multiple of 4. This makes
// converting to 16-bit base 10,000 digits easier.
var exp int32
switch src.Exp % 4 {
case 1, -3:
exp = src.Exp - 1
absInt.Mul(absInt, big10)
case 2, -2:
exp = src.Exp - 2
absInt.Mul(absInt, big100)
case 3, -1:
exp = src.Exp - 3
absInt.Mul(absInt, big1000)
default:
exp = src.Exp
}
if exp < 0 {
divisor := &big.Int{}
divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
wholePart.DivMod(absInt, divisor, fracPart)
} else {
wholePart = absInt
}
var wholeDigits, fracDigits []int16
for wholePart.Cmp(big0) != 0 {
wholePart.DivMod(wholePart, bigNBase, remainder)
wholeDigits = append(wholeDigits, int16(remainder.Int64()))
}
for fracPart.Cmp(big0) != 0 {
fracPart.DivMod(fracPart, bigNBase, remainder)
fracDigits = append(fracDigits, int16(remainder.Int64()))
}
if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil {
return false, err
}
var weight int16
if len(wholeDigits) > 0 {
weight = int16(len(wholeDigits) - 1)
if exp > 0 {
weight += int16(exp / 4)
}
} else {
weight = int16(exp/4) - 1 + int16(len(fracDigits))
}
if _, err := pgio.WriteInt16(w, weight); err != nil {
return false, err
}
if _, err := pgio.WriteInt16(w, sign); err != nil {
return false, err
}
var dscale int16
if src.Exp < 0 {
dscale = int16(-src.Exp)
}
if _, err := pgio.WriteInt16(w, dscale); err != nil {
return false, err
}
for i := len(wholeDigits) - 1; i >= 0; i-- {
if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil {
return false, err
}
}
for i := len(fracDigits) - 1; i >= 0; i-- {
if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil {
return false, err
}
}
return false, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Numeric) Scan(src interface{}) error {
if src == nil {
*dst = Numeric{Status: Null}
return nil
}
switch src := src.(type) {
case float64:
// TODO
// *dst = Numeric{Float: src, Status: Present}
return nil
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
return dst.DecodeText(nil, src)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src *Numeric) Value() (driver.Value, error) {
switch src.Status {
case Present:
buf := &bytes.Buffer{}
_, err := src.EncodeText(nil, buf)
if err != nil {
return nil, err
}
return buf.String(), nil
case Null:
return nil, nil
default:
return nil, errUndefined
}
}

315
pgtype/numeric_test.go Normal file
View File

@ -0,0 +1,315 @@
package pgtype_test
import (
"math/big"
"math/rand"
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
)
// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0)
func numericEqual(left, right *pgtype.Numeric) bool {
return left.Status == right.Status &&
left.Exp == right.Exp &&
((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0))
}
// For test purposes only.
func numericNormalizedEqual(left, right *pgtype.Numeric) bool {
if left.Status != right.Status {
return false
}
normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status}
normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status}
if left.Exp < right.Exp {
mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil)
normRight.Int.Mul(normRight.Int, mul)
} else if left.Exp > right.Exp {
mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil)
normLeft.Int.Mul(normLeft.Int, mul)
}
return normLeft.Int.Cmp(normRight.Int) == 0
}
func mustParseBigInt(t *testing.T, src string) *big.Int {
i := &big.Int{}
if _, ok := i.SetString(src, 10); !ok {
t.Fatalf("could not parse big.Int: %s", src)
}
return i
}
func TestNumericNormalize(t *testing.T) {
testSuccessfulNormalize(t, []normalizeTest{
{
sql: "select '0'::numeric",
value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present},
},
{
sql: "select '1'::numeric",
value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present},
},
{
sql: "select '10.00'::numeric",
value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present},
},
{
sql: "select '1e-3'::numeric",
value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present},
},
{
sql: "select '-1'::numeric",
value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present},
},
{
sql: "select '10000'::numeric",
value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present},
},
{
sql: "select '3.14'::numeric",
value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present},
},
{
sql: "select '1.1'::numeric",
value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present},
},
{
sql: "select '100010001'::numeric",
value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present},
},
{
sql: "select '100010001.0001'::numeric",
value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present},
},
{
sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric",
value: pgtype.Numeric{
Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"),
Exp: -41,
Status: pgtype.Present,
},
},
{
sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric",
value: pgtype.Numeric{
Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"),
Exp: -196,
Status: pgtype.Present,
},
},
{
sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric",
value: pgtype.Numeric{
Int: mustParseBigInt(t, "123"),
Exp: -186,
Status: pgtype.Present,
},
},
})
}
func TestNumericTranscode(t *testing.T) {
testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{
&pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present},
// preserves significant zeroes
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present},
&pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present},
&pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present},
&pgtype.Numeric{Status: pgtype.Null},
}, func(aa, bb interface{}) bool {
a := aa.(pgtype.Numeric)
b := bb.(pgtype.Numeric)
return numericEqual(&a, &b)
})
}
func TestNumericTranscodeFuzz(t *testing.T) {
r := rand.New(rand.NewSource(0))
max := &big.Int{}
max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10)
values := make([]interface{}, 0, 2000)
for i := 0; i < 10; i++ {
for j := -50; j < 50; j++ {
num := (&big.Int{}).Rand(r, max)
negNum := &big.Int{}
negNum.Neg(num)
values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present})
values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present})
}
}
testSuccessfulTranscodeEqFunc(t, "numeric", values,
func(aa, bb interface{}) bool {
a := aa.(pgtype.Numeric)
b := bb.(pgtype.Numeric)
return numericNormalizedEqual(&a, &b)
})
}
func TestNumericSet(t *testing.T) {
successfulTests := []struct {
source interface{}
result *pgtype.Numeric
}{
{source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}},
{source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}},
{source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}},
{source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}},
{source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}},
{source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}},
{source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}},
{source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}},
{source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}},
}
for i, tt := range successfulTests {
r := &pgtype.Numeric{}
err := r.Set(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !numericEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestNumericAssignTo(t *testing.T) {
var i8 int8
var i16 int16
var i32 int32
var i64 int64
var i int
var ui8 uint8
var ui16 uint16
var ui32 uint32
var ui64 uint64
var ui uint
var pi8 *int8
var _i8 _int8
var _pi8 *_int8
var f32 float32
var f64 float64
var pf32 *float32
var pf64 *float64
simpleTests := []struct {
src *pgtype.Numeric
dst interface{}
expected interface{}
}{
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
pointerAllocTests := []struct {
src *pgtype.Numeric
dst interface{}
expected interface{}
}{
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)},
}
for i, tt := range pointerAllocTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
errorTests := []struct {
src *pgtype.Numeric
dst interface{}
}{
{src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8},
{src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16},
{src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8},
{src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16},
{src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32},
{src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64},
{src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui},
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32},
}
for i, tt := range errorTests {
err := tt.src.AssignTo(tt.dst)
if err == nil {
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
}
}
}

View File

@ -228,6 +228,7 @@ func init() {
"cidr": &Cidr{},
"date": &Date{},
"daterange": &Daterange{},
"decimal": &Decimal{},
"float4": &Float4{},
"float8": &Float8{},
"hstore": &Hstore{},
@ -240,6 +241,7 @@ func init() {
"json": &Json{},
"jsonb": &Jsonb{},
"name": &Name{},
"numeric": &Numeric{},
"oid": &OidValue{},
"record": &Record{},
"text": &Text{},

View File

@ -118,6 +118,16 @@ func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interfac
if dt, ok := ci.DataTypeForOid(oid); ok {
if _, ok := dt.Value.(pgtype.BinaryEncoder); ok {
if arg, ok := arg.(driver.Valuer); ok {
if err := dt.Value.Set(arg); err != nil {
if value, err := arg.Value(); err == nil {
if _, ok := value.(string); ok {
return TextFormatCode
}
}
}
}
return BinaryFormatCode
}
}