diff --git a/assert/assertions.go b/assert/assertions.go index 2457a0d..2ca22eb 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -352,6 +352,19 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) } +// validateEqualArgs checks whether provided arguments can be safely used in the +// Equal/NotEqual functions. +func validateEqualArgs(expected, actual interface{}) error { + if expected == nil && actual == nil { + return nil + } + + if isFunction(expected) || isFunction(actual) { + return errors.New("cannot take func type as argument") + } + return nil +} + // Same asserts that two pointers reference the same object. // // assert.Same(t, ptr1, ptr2) @@ -1526,15 +1539,6 @@ func diff(expected interface{}, actual interface{}) string { return "\n\nDiff:\n" + diff } -// validateEqualArgs checks whether provided arguments can be safely used in the -// Equal/NotEqual functions. -func validateEqualArgs(expected, actual interface{}) error { - if isFunction(expected) || isFunction(actual) { - return errors.New("cannot take func type as argument") - } - return nil -} - func isFunction(arg interface{}) bool { if arg == nil { return false diff --git a/assert/assertions_test.go b/assert/assertions_test.go index f398820..b851116 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -131,6 +131,12 @@ func TestObjectsAreEqual(t *testing.T) { if ObjectsAreEqual(0.1, 0) { t.Error("objectsAreEqual should return false") } + if ObjectsAreEqual(time.Now, time.Now) { + t.Error("objectsAreEqual should return false") + } + if ObjectsAreEqual(func() {}, func() {}) { + t.Error("objectsAreEqual should return false") + } if ObjectsAreEqual(uint32(10), int32(10)) { t.Error("objectsAreEqual should return false") } @@ -515,6 +521,9 @@ func TestNotEqual(t *testing.T) { if NotEqual(mockT, funcA, funcB) { t.Error("NotEqual should return false") } + if NotEqual(mockT, nil, nil) { + t.Error("NotEqual should return false") + } if NotEqual(mockT, "Hello World", "Hello World") { t.Error("NotEqual should return false") @@ -850,8 +859,8 @@ func TestPanicsWithError(t *testing.T) { mockT := new(testing.T) - if !PanicsWithError(mockT, "Panic!", func() { - panic(errors.New("Panic!")) + if !PanicsWithError(mockT, "panic", func() { + panic(errors.New("panic")) }) { t.Error("PanicsWithError should return true") } @@ -862,13 +871,13 @@ func TestPanicsWithError(t *testing.T) { } if PanicsWithError(mockT, "at the disco", func() { - panic(errors.New("Panic!")) + panic(errors.New("panic")) }) { t.Error("PanicsWithError should return false") } if PanicsWithError(mockT, "Panic!", func() { - panic("Panic!") + panic("panic") }) { t.Error("PanicsWithError should return false") } @@ -1922,11 +1931,6 @@ func BenchmarkBytesEqual(b *testing.B) { } } -func TestEqualArgsValidation(t *testing.T) { - err := validateEqualArgs(time.Now, time.Now) - EqualError(t, err, "cannot take func type as argument") -} - func ExampleComparisonAssertionFunc() { t := &testing.T{} // provided by test @@ -2153,3 +2157,17 @@ func TestEventuallyIssue805(t *testing.T) { False(t, Eventually(mockT, condition, time.Millisecond, time.Microsecond)) }) } + +func Test_validateEqualArgs(t *testing.T) { + if validateEqualArgs(func() {}, func() {}) == nil { + t.Error("non-nil functions should error") + } + + if validateEqualArgs(func() {}, func() {}) == nil { + t.Error("non-nil functions should error") + } + + if validateEqualArgs(nil, nil) != nil { + t.Error("nil functions are equal") + } +}