From ce654ad1e1ef5bde386d2873a19a39092d5ea6f4 Mon Sep 17 00:00:00 2001 From: Wei Congrui Date: Fri, 18 Aug 2017 15:20:39 +0800 Subject: [PATCH] Fix numeric EncodeBinary bug --- pgtype/numeric.go | 10 +++++++--- pgtype/numeric_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index fded6359..fb63df75 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -16,6 +16,7 @@ import ( const nbase = 10000 var big0 *big.Int = big.NewInt(0) +var big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10) var big100 *big.Int = big.NewInt(100) var big1000 *big.Int = big.NewInt(1000) @@ -507,6 +508,7 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { divisor := &big.Int{} divisor.Exp(big10, big.NewInt(int64(-exp)), nil) wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) } else { wholePart = absInt } @@ -518,9 +520,11 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { wholeDigits = append(wholeDigits, int16(remainder.Int64())) } - for fracPart.Cmp(big0) != 0 { - fracPart.DivMod(fracPart, bigNBase, remainder) - fracDigits = append(fracDigits, int16(remainder.Int64())) + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } } buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 5f3a3416..9d7d83d6 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -317,3 +317,39 @@ func TestNumericAssignTo(t *testing.T) { } } } + +func TestNumericEncodeDecodeBinary(t *testing.T) { + ci := pgtype.NewConnInfo() + tests := []interface{}{ + 123, + 0.000012345, + 1.00002345, + } + + for i, tt := range tests { + toString := func(n *pgtype.Numeric) string { + ci := pgtype.NewConnInfo() + text, err := n.EncodeText(ci, nil) + if err != nil { + t.Errorf("%d: %v", i, err) + } + return string(text) + } + numeric := &pgtype.Numeric{} + numeric.Set(tt) + + encoded, err := numeric.EncodeBinary(ci, nil) + if err != nil { + t.Errorf("%d: %v", i, err) + } + decoded := &pgtype.Numeric{} + decoded.DecodeBinary(ci, encoded) + + text0 := toString(numeric) + text1 := toString(decoded) + + if text0 != text1 { + t.Errorf("%d: expected %v to equal to %v, but doesn't", i, text0, text1) + } + } +}