diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go deleted file mode 100644 index 3a9d99ba..00000000 --- a/ext/shopspring-numeric/decimal.go +++ /dev/null @@ -1,329 +0,0 @@ -package numeric - -import ( - "database/sql/driver" - "fmt" - "strconv" - - "github.com/jackc/pgtype" - "github.com/shopspring/decimal" -) - -type Numeric struct { - Decimal decimal.Decimal - Valid bool -} - -func (dst *Numeric) Set(src interface{}) error { - if src == nil { - *dst = Numeric{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case decimal.Decimal: - *dst = Numeric{Decimal: value, Valid: true} - case decimal.NullDecimal: - if value.Valid { - *dst = Numeric{Decimal: value.Decimal, Valid: true} - } else { - *dst = Numeric{} - } - case float32: - *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Valid: true} - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(value), Valid: true} - case int8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint8: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint16: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint32: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case int64: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint64: - // uint64 could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - case int: - *dst = Numeric{Decimal: decimal.New(int64(value), 0), Valid: true} - case uint: - // uint could be greater than int64 so convert to string then to decimal - dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - case string: - dec, err := decimal.NewFromString(value) - if err != nil { - return err - } - *dst = Numeric{Decimal: dec, Valid: true} - default: - // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. - num := &pgtype.Numeric{} - if err := num.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - - buf, err := num.EncodeText(nil, nil) - if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) - } - *dst = Numeric{Decimal: dec, Valid: true} - } - - return nil -} - -func (dst Numeric) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.Decimal -} - -func (src *Numeric) AssignTo(dst interface{}) error { - if !src.Valid { - if v, ok := dst.(*decimal.NullDecimal); ok { - (*v).Valid = false - (*v).Decimal = src.Decimal - return nil - } - return pgtype.NullAssignTo(dst) - } - - switch v := dst.(type) { - case *decimal.Decimal: - *v = src.Decimal - case *decimal.NullDecimal: - (*v).Valid = true - (*v).Decimal = src.Decimal - case *float32: - f, _ := src.Decimal.Float64() - *v = float32(f) - case *float64: - f, _ := src.Decimal.Float64() - *v = f - case *int: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int(n) - case *int8: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int8(n) - case *int16: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int16(n) - case *int32: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int32(n) - case *int64: - if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = int64(n) - case *uint: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint(n) - case *uint8: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint8(n) - case *uint16: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint16(n) - case *uint32: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint32(n) - case *uint64: - if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) - if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) - } - *v = uint64(n) - default: - if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } - - return nil -} - -func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{} - return nil - } - - dec, err := decimal.NewFromString(string(src)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Valid: true} - return nil -} - -func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = Numeric{} - return nil - } - - // For now at least, implement this in terms of pgtype.Numeric - - num := &pgtype.Numeric{} - if err := num.DecodeBinary(ci, src); err != nil { - return err - } - - *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Valid: true} - - return nil -} - -func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.Decimal.String()...), nil -} - -func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - // For now at least, implement this in terms of pgtype.Numeric - num := &pgtype.Numeric{} - if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { - return nil, err - } - - return num.EncodeBinary(ci, buf) -} - -// Scan implements the database/sql Scanner interface. -func (dst *Numeric) Scan(src interface{}) error { - if src == nil { - *dst = Numeric{} - return nil - } - - switch src := src.(type) { - case float64: - *dst = Numeric{Decimal: decimal.NewFromFloat(src), Valid: true} - return nil - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Numeric) Value() (driver.Value, error) { - if !src.Valid { - return nil, nil - } - return src.Decimal.Value() -} - -func (src Numeric) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return src.Decimal.MarshalJSON() -} - -func (dst *Numeric) UnmarshalJSON(b []byte) error { - d := decimal.NullDecimal{} - err := d.UnmarshalJSON(b) - if err != nil { - return err - } - - *dst = Numeric{Decimal: d.Decimal, Valid: d.Valid} - - return nil -} diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go deleted file mode 100644 index d130a69a..00000000 --- a/ext/shopspring-numeric/decimal_test.go +++ /dev/null @@ -1,360 +0,0 @@ -package numeric_test - -import ( - "fmt" - "math/big" - "math/rand" - "reflect" - "testing" - - shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" - "github.com/jackc/pgtype/testutil" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/require" -) - -func mustParseDecimal(t *testing.T, src string) decimal.Decimal { - dec, err := decimal.NewFromString(src) - if err != nil { - t.Fatal(err) - } - return dec -} - -func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ - { - SQL: "select '0'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, - }, - { - SQL: "select '1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, - }, - { - SQL: "select '10.00'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Valid: true}, - }, - { - SQL: "select '1e-3'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, - }, - { - SQL: "select '-1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, - }, - { - SQL: "select '10000'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Valid: true}, - }, - { - SQL: "select '3.14'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, - }, - { - SQL: "select '1.1'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Valid: true}, - }, - { - SQL: "select '100010001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Valid: true}, - }, - { - SQL: "select '100010001.0001'::numeric", - Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Valid: true}, - }, - { - SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), - Valid: true, - }, - }, - { - SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), - Valid: true, - }, - }, - { - SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: &shopspring.Numeric{ - Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), - Valid: true, - }, - }, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericTranscode(t *testing.T) { - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Valid: true}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Valid: true}, - - &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Valid: true}, - &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Valid: true}, - &shopspring.Numeric{}, - }, func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) - -} - -func TestNumericTranscodeFuzz(t *testing.T) { - r := rand.New(rand.NewSource(0)) - max := &big.Int{} - max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) - - values := make([]interface{}, 0, 2000) - for i := 0; i < 500; i++ { - num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) - negNum := "-" + num - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Valid: true}) - values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Valid: true}) - } - - testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, - func(aa, bb interface{}) bool { - a := aa.(shopspring.Numeric) - b := bb.(shopspring.Numeric) - - return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal) - }) -} - -func TestNumericSet(t *testing.T) { - type _int8 int8 - - successfulTests := []struct { - source interface{} - result *shopspring.Numeric - }{ - {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{}}, - {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}}, - {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}}, - {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Valid: true}}, - {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Valid: true}}, - {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Valid: true}}, - {source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Valid: true}}, - } - - for i, tt := range successfulTests { - r := &shopspring.Numeric{} - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !(r.Valid == tt.result.Valid && r.Decimal.Equal(tt.result.Decimal)) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestNumericAssignTo(t *testing.T) { - type _int8 int8 - - var i8 int8 - var i16 int16 - var i32 int32 - var i64 int64 - var i int - var ui8 uint8 - var ui16 uint16 - var ui32 uint32 - var ui64 uint64 - var ui uint - var pi8 *int8 - var _i8 _int8 - var _pi8 *_int8 - var f32 float32 - var f64 float64 - var pf32 *float32 - var pf64 *float64 - var d decimal.Decimal - var nd decimal.NullDecimal - - simpleTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f64, expected: float64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f32, expected: float32(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f64, expected: float64(4.2)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i16, expected: int16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i32, expected: int32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i64, expected: int64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &i64, expected: int64(42000)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i, expected: int(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui8, expected: uint8(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui16, expected: uint16(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui32, expected: uint32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui64, expected: uint64(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui, expected: uint(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &_i8, expected: _int8(42)}, - {src: &shopspring.Numeric{}, dst: &pi8, expected: ((*int8)(nil))}, - {src: &shopspring.Numeric{}, dst: &_pi8, expected: ((*_int8)(nil))}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &d, expected: decimal.New(42, 0)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &d, expected: decimal.New(42, 3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Valid: true}, dst: &d, expected: decimal.New(42, -3)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, - {src: &shopspring.Numeric{}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, - } - - for i, tt := range simpleTests { - // Zero out the destination variable - reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem())) - - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this - // we end up checking reference equality on the *big.Int they contain. - switch dst := tt.dst.(type) { - case *decimal.Decimal: - if !dst.Equal(tt.expected.(decimal.Decimal)) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d) - } - case *decimal.NullDecimal: - expected := tt.expected.(decimal.NullDecimal) - - if dst.Valid != expected.Valid { - t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid) - } - if !dst.Decimal.Equal(expected.Decimal) { - t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal) - } - default: - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - } - - pointerAllocTests := []struct { - src *shopspring.Numeric - dst interface{} - expected interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf32, expected: float32(42)}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf64, expected: float64(42)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src *shopspring.Numeric - dst interface{} - }{ - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Valid: true}, dst: &i8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Valid: true}, dst: &i16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui8}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui16}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui32}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui64}, - {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui}, - {src: &shopspring.Numeric{}, dst: &i32}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} - -func BenchmarkDecode(b *testing.B) { - benchmarks := []struct { - name string - numberStr string - }{ - {"Zero", "0"}, - {"Small", "12345"}, - {"Medium", "12345.12345"}, - {"Large", "123457890.1234567890"}, - {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, - } - - for _, bm := range benchmarks { - src := &shopspring.Numeric{} - err := src.Set(bm.numberStr) - require.NoError(b, err) - textFormat, err := src.EncodeText(nil, nil) - require.NoError(b, err) - binaryFormat, err := src.EncodeBinary(nil, nil) - require.NoError(b, err) - - b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeText(nil, textFormat) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { - dst := &shopspring.Numeric{} - for i := 0; i < b.N; i++ { - err := dst.DecodeBinary(nil, binaryFormat) - if err != nil { - b.Fatal(err) - } - } - }) - } -}