diff --git a/pgtype/decimal.go b/pgtype/decimal.go new file mode 100644 index 00000000..728c748e --- /dev/null +++ b/pgtype/decimal.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Decimal Numeric + +func (dst *Decimal) Set(src interface{}) error { + return (*Numeric)(dst).Set(src) +} + +func (dst *Decimal) Get() interface{} { + return (*Numeric)(dst).Get() +} + +func (src *Decimal) AssignTo(dst interface{}) error { + return (*Numeric)(src).AssignTo(dst) +} + +func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeText(ci, src) +} + +func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeBinary(ci, src) +} + +func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeText(ci, w) +} + +func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/numeric.go b/pgtype/numeric.go new file mode 100644 index 00000000..0f3f6529 --- /dev/null +++ b/pgtype/numeric.go @@ -0,0 +1,602 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +var big0 *big.Int = big.NewInt(0) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + Status Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case float64: + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case int8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int64: + *dst = Numeric{Int: big.NewInt(value), Status: Present} + case uint64: + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + case int: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint: + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *float32: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *float64: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", dst) + } + return num, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Int: num, Exp: exp, Status: Present} + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := ndigits - weight - 1 + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + *dst = Numeric{Int: accum, Exp: exp, Status: Present} + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := io.WriteString(w, src.Int.String()); err != nil { + return false, err + } + + if err := pgio.WriteByte(w, 'e'); err != nil { + return false, err + } + + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil { + return false, err + } + + return false, nil + +} + +func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + for fracPart.Cmp(big0) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + + if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil { + return false, err + } + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + if _, err := pgio.WriteInt16(w, weight); err != nil { + return false, err + } + + if _, err := pgio.WriteInt16(w, sign); err != nil { + return false, err + } + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + if _, err := pgio.WriteInt16(w, dscale); err != nil { + return false, err + } + + for i := len(wholeDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil { + return false, err + } + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + // TODO + // *dst = Numeric{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case Present: + buf := &bytes.Buffer{} + _, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + + return buf.String(), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go new file mode 100644 index 00000000..64dea847 --- /dev/null +++ b/pgtype/numeric_test.go @@ -0,0 +1,315 @@ +package pgtype_test + +import ( + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Status == right.Status && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) +} + +// For test purposes only. +func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Status != right.Status { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '0'::numeric", + value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10.00'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1e-3'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + }, + { + sql: "select '-1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10000'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + }, + { + sql: "select '3.14'::numeric", + value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1.1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + }, + { + sql: "select '100010001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '100010001.0001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + }, + { + sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + } + } + + testSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3d691044..84939b58 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -228,6 +228,7 @@ func init() { "cidr": &Cidr{}, "date": &Date{}, "daterange": &Daterange{}, + "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, @@ -240,6 +241,7 @@ func init() { "json": &Json{}, "jsonb": &Jsonb{}, "name": &Name{}, + "numeric": &Numeric{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, diff --git a/values.go b/values.go index 5370bf47..71c4cc5c 100644 --- a/values.go +++ b/values.go @@ -118,6 +118,16 @@ func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interfac if dt, ok := ci.DataTypeForOid(oid); ok { if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { + if arg, ok := arg.(driver.Valuer); ok { + if err := dt.Value.Set(arg); err != nil { + if value, err := arg.Value(); err == nil { + if _, ok := value.(string); ok { + return TextFormatCode + } + } + } + } + return BinaryFormatCode } }