Allow nil-function Equal comparisons

pull/867/head
Boyan Soubachov 2019-12-19 16:10:44 +02:00 committed by George Lesica
parent 22d5528225
commit 858f37ff9b
2 changed files with 40 additions and 18 deletions

View File

@ -352,6 +352,19 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
}
// validateEqualArgs checks whether provided arguments can be safely used in the
// Equal/NotEqual functions.
func validateEqualArgs(expected, actual interface{}) error {
if expected == nil && actual == nil {
return nil
}
if isFunction(expected) || isFunction(actual) {
return errors.New("cannot take func type as argument")
}
return nil
}
// Same asserts that two pointers reference the same object.
//
// assert.Same(t, ptr1, ptr2)
@ -1526,15 +1539,6 @@ func diff(expected interface{}, actual interface{}) string {
return "\n\nDiff:\n" + diff
}
// validateEqualArgs checks whether provided arguments can be safely used in the
// Equal/NotEqual functions.
func validateEqualArgs(expected, actual interface{}) error {
if isFunction(expected) || isFunction(actual) {
return errors.New("cannot take func type as argument")
}
return nil
}
func isFunction(arg interface{}) bool {
if arg == nil {
return false

View File

@ -131,6 +131,12 @@ func TestObjectsAreEqual(t *testing.T) {
if ObjectsAreEqual(0.1, 0) {
t.Error("objectsAreEqual should return false")
}
if ObjectsAreEqual(time.Now, time.Now) {
t.Error("objectsAreEqual should return false")
}
if ObjectsAreEqual(func() {}, func() {}) {
t.Error("objectsAreEqual should return false")
}
if ObjectsAreEqual(uint32(10), int32(10)) {
t.Error("objectsAreEqual should return false")
}
@ -515,6 +521,9 @@ func TestNotEqual(t *testing.T) {
if NotEqual(mockT, funcA, funcB) {
t.Error("NotEqual should return false")
}
if NotEqual(mockT, nil, nil) {
t.Error("NotEqual should return false")
}
if NotEqual(mockT, "Hello World", "Hello World") {
t.Error("NotEqual should return false")
@ -850,8 +859,8 @@ func TestPanicsWithError(t *testing.T) {
mockT := new(testing.T)
if !PanicsWithError(mockT, "Panic!", func() {
panic(errors.New("Panic!"))
if !PanicsWithError(mockT, "panic", func() {
panic(errors.New("panic"))
}) {
t.Error("PanicsWithError should return true")
}
@ -862,13 +871,13 @@ func TestPanicsWithError(t *testing.T) {
}
if PanicsWithError(mockT, "at the disco", func() {
panic(errors.New("Panic!"))
panic(errors.New("panic"))
}) {
t.Error("PanicsWithError should return false")
}
if PanicsWithError(mockT, "Panic!", func() {
panic("Panic!")
panic("panic")
}) {
t.Error("PanicsWithError should return false")
}
@ -1922,11 +1931,6 @@ func BenchmarkBytesEqual(b *testing.B) {
}
}
func TestEqualArgsValidation(t *testing.T) {
err := validateEqualArgs(time.Now, time.Now)
EqualError(t, err, "cannot take func type as argument")
}
func ExampleComparisonAssertionFunc() {
t := &testing.T{} // provided by test
@ -2153,3 +2157,17 @@ func TestEventuallyIssue805(t *testing.T) {
False(t, Eventually(mockT, condition, time.Millisecond, time.Microsecond))
})
}
func Test_validateEqualArgs(t *testing.T) {
if validateEqualArgs(func() {}, func() {}) == nil {
t.Error("non-nil functions should error")
}
if validateEqualArgs(func() {}, func() {}) == nil {
t.Error("non-nil functions should error")
}
if validateEqualArgs(nil, nil) != nil {
t.Error("nil functions are equal")
}
}