diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 426b6782..b9827b63 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -109,6 +109,33 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{Float64: f, Valid: true}, nil } +func (n *Numeric) ScanInt64(v Int8) error { + if !v.Valid { + *n = Numeric{} + return nil + } + + *n = Numeric{Int: big.NewInt(v.Int64), Valid: true} + return nil +} + +func (n Numeric) Int64Value() (Int8, error) { + if !n.Valid { + return Int8{}, nil + } + + bi, err := n.toBigInt() + if err != nil { + return Int8{}, err + } + + if !bi.IsInt64() { + return Int8{}, fmt.Errorf("cannot convert %v to int64", n) + } + + return Int8{Int64: bi.Int64(), Valid: true}, nil +} + func (n *Numeric) toBigInt() (*big.Int, error) { if n.Exp == 0 { return n.Int, nil @@ -450,18 +477,15 @@ func (encodePlanNumericCodecTextFloat64Valuer) Encode(value interface{}, buf []b } if math.IsNaN(n.Float64) { - return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf) + buf = append(buf, "NaN"...) } else if math.IsInf(n.Float64, 1) { - return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf) + buf = append(buf, "Infinity"...) } else if math.IsInf(n.Float64, -1) { - return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf) + buf = append(buf, "-Infinity"...) + } else { + buf = append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...) } - num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64)) - if err != nil { - return nil, err - } - - return encodeNumericText(Numeric{Int: num, Exp: exp, Valid: true}, buf) + return buf, nil } type encodePlanNumericCodecTextInt64Valuer struct{} @@ -476,7 +500,8 @@ func (encodePlanNumericCodecTextInt64Valuer) Encode(value interface{}, buf []byt return nil, nil } - return encodeNumericText(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf) + buf = append(buf, strconv.FormatInt(n.Int64, 10)...) + return buf, nil } func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { @@ -495,9 +520,20 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { return buf, nil } - buf = append(buf, n.Int.String()...) - buf = append(buf, 'e') - buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + digits := n.Int.String() + if n.Exp >= 0 { + buf = append(buf, digits...) + if n.Exp > 0 { + for i := int32(0); i < n.Exp; i++ { + buf = append(buf, '0') + } + } + } else { + buf = append(buf, digits...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...) + } + return buf, nil } diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 8be8ce55..3c37ae18 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -113,6 +113,12 @@ func TestNumericCodec(t *testing.T) { {pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, {nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})}, }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ + {mustParseNumeric(t, "-1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))}, + {mustParseNumeric(t, "0"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))}, + {mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))}, + }) } func TestNumericCodecInfinity(t *testing.T) {