From 0b762c6e268d5deeda378e54fcd161082d290ef6 Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Wed, 10 Jun 2020 16:59:08 +0100 Subject: [PATCH] updated to use boolean IsNaN field on Numeric --- numeric.go | 34 +++++++++++++++++++++++----------- numeric_test.go | 18 ++++++++++++------ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/numeric.go b/numeric.go index 7ee517be..074c2edc 100644 --- a/numeric.go +++ b/numeric.go @@ -52,6 +52,7 @@ type Numeric struct { Int *big.Int Exp int32 Status Status + IsNaN bool } func (dst *Numeric) Set(src interface{}) error { @@ -70,6 +71,7 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { + *dst = Numeric{Status: Present, IsNaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) @@ -79,6 +81,7 @@ func (dst *Numeric) Set(src interface{}) error { *dst = Numeric{Int: num, Exp: exp, Status: Present} case float64: if math.IsNaN(value) { + *dst = Numeric{Status: Present, IsNaN: true} return nil } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) @@ -272,13 +275,6 @@ func (src *Numeric) AssignTo(dst interface{}) error { } case Null: return NullAssignTo(dst) - case Undefined: - switch v := dst.(type) { - case *float32: - *v = float32(math.NaN()) - case *float64: - *v = math.NaN() - } } return nil @@ -309,6 +305,10 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { } func (src *Numeric) toFloat64() (float64, error) { + if src.IsNaN { + return math.NaN(), nil + } + buf := make([]byte, 0, 32) buf = append(buf, src.Int.String()...) @@ -328,8 +328,8 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } - if string(src) == "NaN" { - *dst = Numeric{} + if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. + *dst = Numeric{Status: Present, IsNaN: true} return nil } @@ -384,7 +384,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{} + *dst = Numeric{Status: Present, IsNaN: true} return nil } @@ -491,7 +491,15 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - buf = append(buf, []byte("NaN")...) + return nil, errUndefined + } + + if src.IsNaN { + // encode as 'NaN' including single quotes, + // "When writing this value [NaN] as a constant in an SQL command, + // you must put quotes around it, for example UPDATE table SET x = 'NaN'" + // https://www.postgresql.org/docs/9.3/datatype-numeric.html + buf = append(buf, "'NaN'"...) return buf, nil } @@ -506,6 +514,10 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: + return nil, errUndefined + } + + if src.IsNaN { buf = pgio.AppendUint64(buf, pgNumericNaN) return buf, nil } diff --git a/numeric_test.go b/numeric_test.go index 259f397e..4d9c5252 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -210,8 +210,8 @@ func TestNumericSet(t *testing.T) { {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, - {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, - {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Undefined}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, } for i, tt := range successfulTests { @@ -269,8 +269,8 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 - {src: &pgtype.Numeric{}, dst: &f64, expected: math.NaN()}, - {src: &pgtype.Numeric{}, dst: &f32, expected: float32(math.NaN())}, + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f32, expected: float32(math.NaN())}, } for i, tt := range simpleTests { @@ -282,11 +282,17 @@ func TestNumericAssignTo(t *testing.T) { dst := reflect.ValueOf(tt.dst).Elem().Interface() switch dstTyped := dst.(type) { case float32: - if math.IsNaN(float64(tt.expected.(float32))) && !math.IsNaN(float64(dstTyped)) { + nanExpected := math.IsNaN(float64(tt.expected.(float32))) + if nanExpected && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } case float64: - if math.IsNaN(tt.expected.(float64)) && !math.IsNaN(dstTyped) { + nanExpected := math.IsNaN(tt.expected.(float64)) + if nanExpected && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } default: