package numeric_test

import (
	"fmt"
	"math/big"
	"math/rand"
	"reflect"
	"testing"

	shopspring "github.com/jackc/pgtype/ext/shopspring-numeric"
	"github.com/jackc/pgtype/testutil"
	"github.com/shopspring/decimal"
	"github.com/stretchr/testify/require"
)

// mustParseDecimal parses src with decimal.NewFromString and aborts the
// calling test via t.Fatal if src is not a valid decimal string.
func mustParseDecimal(t *testing.T, src string) decimal.Decimal {
	// Attribute a parse failure to the caller's line, not to this helper.
	t.Helper()

	dec, err := decimal.NewFromString(src)
	if err != nil {
		t.Fatal(err)
	}
	return dec
}

// eqShopspringNumeric reports whether two shopspring.Numeric values (boxed by
// value in interface{}) have the same Valid flag and numerically equal
// decimals. Decimal.Equal is used because comparing decimal.Decimal with ==
// would compare the *big.Int pointers it contains rather than the values.
func eqShopspringNumeric(aa, bb interface{}) bool {
	a := aa.(shopspring.Numeric)
	b := bb.(shopspring.Numeric)
	return a.Valid == b.Valid && a.Decimal.Equal(b.Decimal)
}

// TestNumericNormalize checks, via testutil.TestSuccessfulNormalizeEqFunc,
// that each SQL numeric literal yields the paired shopspring.Numeric value
// (the helper presumably runs the query against a test database — see testutil).
func TestNumericNormalize(t *testing.T) {
	testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{
		{
			SQL:   "select '0'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true},
		},
		{
			SQL:   "select '1'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true},
		},
		{
			SQL:   "select '10.00'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Valid: true},
		},
		{
			SQL:   "select '1e-3'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true},
		},
		{
			SQL:   "select '-1'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true},
		},
		{
			SQL:   "select '10000'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Valid: true},
		},
		{
			SQL:   "select '3.14'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true},
		},
		{
			SQL:   "select '1.1'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Valid: true},
		},
		{
			SQL:   "select '100010001'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Valid: true},
		},
		{
			SQL:   "select '100010001.0001'::numeric",
			Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Valid: true},
		},
		{
			SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric",
			Value: &shopspring.Numeric{
				Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"),
				Valid:   true,
			},
		},
		{
			SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric",
			Value: &shopspring.Numeric{
				Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"),
				Valid:   true,
			},
		},
		{
			SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric",
			Value: &shopspring.Numeric{
				Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"),
				Valid:   true,
			},
		},
	}, eqShopspringNumeric)
}

// TestNumericTranscode round-trips a fixed set of values through the
// "numeric" type with testutil.TestSuccessfulTranscodeEqFunc. The set includes
// a NULL (zero-value) Numeric.
func TestNumericTranscode(t *testing.T) {
	testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Valid: true},

		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Valid: true},

		&shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Valid: true},
		&shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Valid: true},
		&shopspring.Numeric{},
	}, eqShopspringNumeric)
}

// TestNumericTranscodeFuzz round-trips pseudo-random values through the
// "numeric" type. Each value has an integer part and a fractional part drawn
// from [0, 10^100); each value is tested both positive and negated.
func TestNumericTranscodeFuzz(t *testing.T) {
	// A fixed seed keeps the generated values reproducible across runs.
	r := rand.New(rand.NewSource(0))
	limit := &big.Int{}
	limit.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10)

	const iterations = 500
	// Each iteration appends exactly two values (positive and negative).
	values := make([]interface{}, 0, 2*iterations)
	for i := 0; i < iterations; i++ {
		num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, limit).String(), (&big.Int{}).Rand(r, limit).String())
		negNum := "-" + num
		values = append(values,
			&shopspring.Numeric{Decimal: mustParseDecimal(t, num), Valid: true},
			&shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Valid: true},
		)
	}

	testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, eqShopspringNumeric)
}

// TestNumericSet checks that Numeric.Set accepts decimal, NullDecimal, float,
// signed/unsigned integer, string and named-integer sources.
func TestNumericSet(t *testing.T) {
	type _int8 int8

	successfulTests := []struct {
		source interface{}
		result *shopspring.Numeric
	}{
		{source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{}},
		{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}},
		{source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}},
		{source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}},
		{source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}},
		{source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Valid: true}},
		{source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Valid: true}},
		{source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Valid: true}},
		{source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Valid: true}},
		{source: float64(1.25), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.25"), Valid: true}},
	}

	for i, tt := range successfulTests {
		r := &shopspring.Numeric{}
		if err := r.Set(tt.source); err != nil {
			t.Errorf("%d: %v", i, err)
			// The value comparison below would only add a misleading second error.
			continue
		}

		if !(r.Valid == tt.result.Valid && r.Decimal.Equal(tt.result.Decimal)) {
			t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
		}
	}
}

// TestNumericAssignTo checks Numeric.AssignTo for value destinations,
// pointer destinations that must be allocated, and conversions that must fail
// (overflow, negative to unsigned, NULL to non-pointer).
func TestNumericAssignTo(t *testing.T) {
	type _int8 int8

	var i8 int8
	var i16 int16
	var i32 int32
	var i64 int64
	var i int
	var ui8 uint8
	var ui16 uint16
	var ui32 uint32
	var ui64 uint64
	var ui uint
	var pi8 *int8
	var _i8 _int8
	var _pi8 *_int8
	var f32 float32
	var f64 float64
	var pf32 *float32
	var pf64 *float64
	var d decimal.Decimal
	var nd decimal.NullDecimal

	simpleTests := []struct {
		src      *shopspring.Numeric
		dst      interface{}
		expected interface{}
	}{
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f32, expected: float32(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &f64, expected: float64(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f32, expected: float32(4.2)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Valid: true}, dst: &f64, expected: float64(4.2)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i16, expected: int16(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i32, expected: int32(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i64, expected: int64(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &i64, expected: int64(42000)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &i, expected: int(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui8, expected: uint8(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui16, expected: uint16(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui32, expected: uint32(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui64, expected: uint64(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &ui, expected: uint(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &_i8, expected: _int8(42)},
		{src: &shopspring.Numeric{}, dst: &pi8, expected: ((*int8)(nil))},
		{src: &shopspring.Numeric{}, dst: &_pi8, expected: ((*_int8)(nil))},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &d, expected: decimal.New(42, 0)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Valid: true}, dst: &d, expected: decimal.New(42, 3)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Valid: true}, dst: &d, expected: decimal.New(42, -3)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}},
		{src: &shopspring.Numeric{}, dst: &nd, expected: decimal.NullDecimal{Valid: false}},
	}

	for i, tt := range simpleTests {
		// Zero out the destination variable; several cases share one variable.
		reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem()))
		if err := tt.src.AssignTo(tt.dst); err != nil {
			t.Errorf("%d: %v", i, err)
			continue
		}

		// Decimal and NullDecimal need special handling so their Equal method is
		// used. Comparing with != would check reference equality on the *big.Int
		// they contain instead of the numeric value.
		switch dst := tt.dst.(type) {
		case *decimal.Decimal:
			if !dst.Equal(tt.expected.(decimal.Decimal)) {
				t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *dst)
			}
		case *decimal.NullDecimal:
			expected := tt.expected.(decimal.NullDecimal)
			if dst.Valid != expected.Valid {
				t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid)
			}
			if !dst.Decimal.Equal(expected.Decimal) {
				t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal)
			}
		default:
			if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
				t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
			}
		}
	}

	pointerAllocTests := []struct {
		src      *shopspring.Numeric
		dst      interface{}
		expected interface{}
	}{
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf32, expected: float32(42)},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Valid: true}, dst: &pf64, expected: float64(42)},
	}

	for i, tt := range pointerAllocTests {
		if err := tt.src.AssignTo(tt.dst); err != nil {
			t.Errorf("%d: %v", i, err)
			continue
		}

		// Guard against a nil pointer: Elem() on it yields an invalid Value and
		// Interface() would panic instead of reporting a test failure.
		ptr := reflect.ValueOf(tt.dst).Elem()
		if ptr.IsNil() {
			t.Errorf("%d: expected %v to allocate a pointer, but it is still nil", i, tt.src)
			continue
		}

		if dst := ptr.Elem().Interface(); dst != tt.expected {
			t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
		}
	}

	errorTests := []struct {
		src *shopspring.Numeric
		dst interface{}
	}{
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Valid: true}, dst: &i8},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Valid: true}, dst: &i16},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui8},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui16},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui32},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui64},
		{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Valid: true}, dst: &ui},
		{src: &shopspring.Numeric{}, dst: &i32},
	}

	for i, tt := range errorTests {
		if err := tt.src.AssignTo(tt.dst); err == nil {
			t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
		}
	}
}

// BenchmarkDecode measures DecodeText and DecodeBinary for numbers of
// increasing size. The encoded inputs are produced once, outside the timed loops.
func BenchmarkDecode(b *testing.B) {
	benchmarks := []struct {
		name      string
		numberStr string
	}{
		{"Zero", "0"},
		{"Small", "12345"},
		{"Medium", "12345.12345"},
		{"Large", "123457890.1234567890"},
		{"Huge", "123457890123457890123457890.1234567890123457890123457890"},
	}

	for _, bm := range benchmarks {
		src := &shopspring.Numeric{}
		err := src.Set(bm.numberStr)
		require.NoError(b, err)
		textFormat, err := src.EncodeText(nil, nil)
		require.NoError(b, err)
		binaryFormat, err := src.EncodeBinary(nil, nil)
		require.NoError(b, err)

		b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) {
			dst := &shopspring.Numeric{}
			for i := 0; i < b.N; i++ {
				if err := dst.DecodeText(nil, textFormat); err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) {
			dst := &shopspring.Numeric{}
			for i := 0; i < b.N; i++ {
				if err := dst.DecodeBinary(nil, binaryFormat); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}