diff --git a/assert/assertions.go b/assert/assertions.go index 10c1992..fb2707f 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -173,8 +173,13 @@ func NotNil(t *testing.T, object interface{}, msgAndArgs ...interface{}) bool { if object == nil { success = false - } else if reflect.ValueOf(object).IsNil() { - success = false + } + + val := reflect.ValueOf(object) + if val.CanAddr() { + if reflect.ValueOf(object).IsNil() { + success = false + } } if !success { @@ -201,8 +206,13 @@ func Nil(t *testing.T, object interface{}, msgAndArgs ...interface{}) bool { if object == nil { return true - } else if reflect.ValueOf(object).IsNil() { - return true + } + + val := reflect.ValueOf(object) + if val.CanAddr() { + if reflect.ValueOf(object).IsNil() { + return true + } } if len(message) > 0 { diff --git a/assert/assertions_test.go b/assert/assertions_test.go index c965173..7396dd3 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -5,6 +5,10 @@ import ( "testing" ) +type CustomError struct{} + +func (c CustomError) Error() string { return "Error" } + // AssertionTesterInterface defines an interface to be used for testing assertion methods type AssertionTesterInterface interface { TestMethod() @@ -115,6 +119,54 @@ func TestNil(t *testing.T) { } +func TestNil_ValueType(t *testing.T) { + + f := func() AssertionTesterConformingObject { + return AssertionTesterConformingObject{} + } + + mockT := new(testing.T) + obj := f() + False(t, Nil(mockT, obj)) + +} + +func TestNil_ValueType_IsNil(t *testing.T) { + + f := func() AssertionTesterInterface { + return nil + } + + mockT := new(testing.T) + obj := f() + True(t, Nil(mockT, obj)) + +} + +func TestNil_ValueError(t *testing.T) { + + f := func() error { + return CustomError{} + } + + mockT := new(testing.T) + obj := f() + False(t, Nil(mockT, obj)) + +} + +func TestNotNil_ValueError(t *testing.T) { + + f := func() error { + return CustomError{} + } + + mockT := new(testing.T) + obj := f() + True(t, NotNil(mockT, obj)) + +} + func TestTrue(t *testing.T) { mockT := new(testing.T) @@ -333,3 +385,15 @@ func TestNotEmpty(t *testing.T) { True(t, NotEmpty(mockT, true), "True value is not empty") } + +func TestEmpty_Value(t *testing.T) { + + f := func() error { + return CustomError{} + } + + mockT := new(testing.T) + obj := f() + False(t, Empty(mockT, obj)) + +}