diff --git a/assert/assertions.go b/assert/assertions.go index dc412ff..1dd7fe1 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -368,13 +368,7 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b return Fail(t, "Invalid operation: both arguments must be pointers", msgAndArgs...) } - expectedType, actualType := reflect.TypeOf(expected), reflect.TypeOf(actual) - if expectedType != actualType { - return Fail(t, fmt.Sprintf("Pointer expected to be of type %v, but was %v", - expectedType, actualType), msgAndArgs...) - } - - if expected != actual { + if !sameComparator(expected, actual) { return Fail(t, fmt.Sprintf("Not same: \n"+ "expected: %p %#v\n"+ "actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) @@ -399,16 +393,25 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} return Fail(t, "Invalid operation: both arguments must be pointers", msgAndArgs...) } - expectedType, actualType := reflect.TypeOf(expected), reflect.TypeOf(actual) - if expectedType != actualType { - return true - } - - if expected == actual { + if sameComparator(expected, actual) { return Fail(t, fmt.Sprintf( "Expected and actual point to the same object: %p %#v", expected, expected), msgAndArgs...) } + return true +} + +// sameComparator compares two generic interface objects and returns whether +// they are of the same type and value +func sameComparator(first, second interface{}) bool { + firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) + if firstType != secondType { + return false + } + + if first != second { + return false + } return true } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 37bd186..7053997 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -263,6 +263,59 @@ func TestNotSame(t *testing.T) { } } +func Test_sameComparator(t *testing.T) { + type args struct { + first interface{} + second interface{} + } + tests := []struct { + name string + args args + assertion BoolAssertionFunc + }{ + { + name: "1 != 2", + args: args{first: 1, second: 2}, + assertion: False, + }, + { + name: "1 == 1", + args: args{first: 1, second: 1}, + assertion: True, + }, + { + name: "int(1) != float32(1)", + args: args{first: int(1), second: float32(1)}, + assertion: False, + }, + { + name: "true == true", + args: args{first: true, second: true}, + assertion: True, + }, + { + name: "false != true", + args: args{first: false, second: true}, + assertion: False, + }, + { + name: "array != slice", + args: args{first: [2]int{1, 2}, second: []int{1, 2}}, + assertion: False, + }, + { + name: "array == array", + args: args{first: [2]int{1, 2}, second: [2]int{1, 2}}, + assertion: True, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, sameComparator(tt.args.first, tt.args.second)) + }) + } +} + // bufferT implements TestingT. Its implementation of Errorf writes the output that would be produced by // testing.T.Errorf to an internal bytes.Buffer. type bufferT struct {