From 4c4d0118a66fc95ca42fceb569ed155c953eebbb Mon Sep 17 00:00:00 2001 From: Arjun Mahishi Date: Sat, 10 Feb 2024 12:59:27 +0530 Subject: [PATCH] assert: Fix EqualValues to handle overflow/underflow The underlying function ObjectsAreEqualValues did not handle overflow/underflow of values while converting one type to another for comparison. For example: EqualValues(t, int(270), int8(14)) would return true, even though the values are not equal. Because, when you convert int(270) to int8, it overflows and becomes 14 (270 % 256 = 14) This commit fixes that by making sure that the conversion always happens from the smaller type to the larger type, and then comparing the values. Additionally, this commit also seperates out the test cases of ObjectsAreEqualValues from TestObjectsAreEqual. Fixes #1462 --- assert/assertions.go | 37 +++++++++++++++++++++++++------ assert/assertions_test.go | 46 +++++++++++++++++++++++++++------------ 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/assert/assertions.go b/assert/assertions.go index 909daa5..b3b309b 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -165,17 +165,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { return true } - actualType := reflect.TypeOf(actual) - if actualType == nil { + expectedValue := reflect.ValueOf(expected) + actualValue := reflect.ValueOf(actual) + if !expectedValue.IsValid() || !actualValue.IsValid() { return false } - expectedValue := reflect.ValueOf(expected) - if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { - // Attempt comparison after type conversion - return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + + expectedType := expectedValue.Type() + actualType := actualValue.Type() + if !expectedType.ConvertibleTo(actualType) { + return false } - return false + if !isNumericType(expectedType) || !isNumericType(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual( + expectedValue.Convert(actualType).Interface(), actual, + ) + } + + // If BOTH values are numeric, there are chances of false positives due + // to overflow or underflow. So, we need to make sure to always convert + // the smaller type to a larger type before comparing. + if expectedType.Size() >= actualType.Size() { + return actualValue.Convert(expectedType).Interface() == expected + } + + return expectedValue.Convert(actualType).Interface() == actual +} + +// isNumericType returns true if the type is one of: +// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, +// float32, float64, complex64, complex128 +func isNumericType(t reflect.Type) bool { + return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128 } /* CallerInfo is necessary because the assert functions use the testing object diff --git a/assert/assertions_test.go b/assert/assertions_test.go index d65dd02..a06d6e9 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -135,24 +135,42 @@ func TestObjectsAreEqual(t *testing.T) { }) } +} - // Cases where type differ but values are equal - if !ObjectsAreEqualValues(uint32(10), int32(10)) { - t.Error("ObjectsAreEqualValues should return true") - } - if ObjectsAreEqualValues(0, nil) { - t.Fail() - } - if ObjectsAreEqualValues(nil, 0) { - t.Fail() +func TestObjectsAreEqualValues(t *testing.T) { + now := time.Now() + + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + {uint32(10), int32(10), true}, + {0, nil, false}, + {nil, 0, false}, + {now, now.In(time.Local), true}, // should be time zone independent + {int(270), int8(14), false}, // should handle overflow/underflow + {int8(14), int(270), false}, + {[]int{270, 270}, []int8{14, 14}, false}, + {complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false}, + {complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false}, + {complex128(1e+100 + 1e+100i), 270, false}, + {270, complex128(1e+100 + 1e+100i), false}, + {complex128(1e+100 + 1e+100i), 3.14, false}, + {3.14, complex128(1e+100 + 1e+100i), false}, + {complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true}, + {complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true}, } - tm := time.Now() - tz := tm.In(time.Local) - if !ObjectsAreEqualValues(tm, tz) { - t.Error("ObjectsAreEqualValues should return true for time.Time objects with different time zones") - } + for _, c := range cases { + t.Run(fmt.Sprintf("ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := ObjectsAreEqualValues(c.expected, c.actual) + if res != c.result { + t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + }) + } } type Nested struct {