Add NullDecimal to shopspring-numeric

The shopspring/decimal package provides a NullDecimal struct intended
for use with nullable SQL NUMERICs and numbers. It has Scanner and
Valuer implementations already, but adding it to this package allows
it to be used with the binary encoding as well.

The implementation is very straightforward, but the tests have been made
slightly more complicated. The previous version wasn't testing the
decimal.Decimal cases, and this change adds those as well as new
NullDecimal cases. I've added some logic to the test harness to catch
these as you need to use the Equals method to properly compare Decimals.
non-blocking
Eli Treuherz 2021-08-02 09:26:24 +01:00 committed by Jack Christensen
parent 6bda09691d
commit db84905b7f
2 changed files with 47 additions and 3 deletions

View File

@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error {
switch value := src.(type) {
case decimal.Decimal:
*dst = Numeric{Decimal: value, Status: pgtype.Present}
case decimal.NullDecimal:
if value.Valid {
*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
} else {
*dst = Numeric{Status: pgtype.Null}
}
case float32:
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
case float64:
@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error {
switch v := dst.(type) {
case *decimal.Decimal:
*v = src.Decimal
case *decimal.NullDecimal:
(*v).Valid = true
(*v).Decimal = src.Decimal
case *float32:
f, _ := src.Decimal.Float64()
*v = float32(f)
@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error {
return fmt.Errorf("unable to assign to %T", dst)
}
case pgtype.Null:
return pgtype.NullAssignTo(dst)
if v, ok := dst.(*decimal.NullDecimal); ok {
(*v).Valid = false
} else {
return pgtype.NullAssignTo(dst)
}
}
return nil

View File

@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) {
source interface{}
result *shopspring.Numeric
}{
{source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}},
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) {
var f64 float64
var pf32 *float32
var pf64 *float64
var d decimal.Decimal
var nd decimal.NullDecimal
simpleTests := []struct {
src *shopspring.Numeric
@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) {
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}},
}
for i, tt := range simpleTests {
// Zero out the destination variable
reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem()))
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
// Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this
// we end up checking reference equality on the *big.Int they contain.
switch dst := tt.dst.(type) {
case *decimal.Decimal:
if !dst.Equal(tt.expected.(decimal.Decimal)) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d)
}
case *decimal.NullDecimal:
expected := tt.expected.(decimal.NullDecimal)
if dst.Valid != expected.Valid {
t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid)
}
if !dst.Decimal.Equal(expected.Decimal) {
t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal)
}
default:
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}