mirror of https://github.com/jackc/pgx.git
Add NullDecimal to shopspring-numeric
The shopspring/decimal package provides a NullDecimal struct intended for use with nullable SQL NUMERICs and numbers. It has Scanner and Valuer implementations already, but adding it to this package allows it to be used with the binary encoding as well. The implementation is very straightforward, but the tests have been made slightly more complicated. The previous version wasn't testing the decimal.Decimal cases, and this change adds those as well as new NullDecimal cases. I've added some logic to the test harness to catch these as you need to use the Equals method to properly compare Decimals.non-blocking
parent
6bda09691d
commit
db84905b7f
|
@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error {
|
|||
switch value := src.(type) {
|
||||
case decimal.Decimal:
|
||||
*dst = Numeric{Decimal: value, Status: pgtype.Present}
|
||||
case decimal.NullDecimal:
|
||||
if value.Valid {
|
||||
*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
|
||||
} else {
|
||||
*dst = Numeric{Status: pgtype.Null}
|
||||
}
|
||||
case float32:
|
||||
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
|
||||
case float64:
|
||||
|
@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||
switch v := dst.(type) {
|
||||
case *decimal.Decimal:
|
||||
*v = src.Decimal
|
||||
case *decimal.NullDecimal:
|
||||
(*v).Valid = true
|
||||
(*v).Decimal = src.Decimal
|
||||
case *float32:
|
||||
f, _ := src.Decimal.Float64()
|
||||
*v = float32(f)
|
||||
|
@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||
return fmt.Errorf("unable to assign to %T", dst)
|
||||
}
|
||||
case pgtype.Null:
|
||||
return pgtype.NullAssignTo(dst)
|
||||
if v, ok := dst.(*decimal.NullDecimal); ok {
|
||||
(*v).Valid = false
|
||||
} else {
|
||||
return pgtype.NullAssignTo(dst)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) {
|
|||
source interface{}
|
||||
result *shopspring.Numeric
|
||||
}{
|
||||
{source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}},
|
||||
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||
|
@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) {
|
|||
var f64 float64
|
||||
var pf32 *float32
|
||||
var pf64 *float64
|
||||
var d decimal.Decimal
|
||||
var nd decimal.NullDecimal
|
||||
|
||||
simpleTests := []struct {
|
||||
src *shopspring.Numeric
|
||||
|
@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) {
|
|||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)},
|
||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}},
|
||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
// Zero out the destination variable
|
||||
reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem()))
|
||||
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
// Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this
|
||||
// we end up checking reference equality on the *big.Int they contain.
|
||||
switch dst := tt.dst.(type) {
|
||||
case *decimal.Decimal:
|
||||
if !dst.Equal(tt.expected.(decimal.Decimal)) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d)
|
||||
}
|
||||
case *decimal.NullDecimal:
|
||||
expected := tt.expected.(decimal.NullDecimal)
|
||||
|
||||
if dst.Valid != expected.Valid {
|
||||
t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid)
|
||||
}
|
||||
if !dst.Decimal.Equal(expected.Decimal) {
|
||||
t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal)
|
||||
}
|
||||
default:
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue