diff --git a/mock/mock.go b/mock/mock.go index bead86d..b1bfc3b 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -27,6 +27,10 @@ type Call struct { // Holds the arguments that should be returned when // this method is called. ReturnArguments Arguments + + // The number of times to return the return arguments when setting + // expectations. 0 means to always return the value. + Repeatability int } // Mock is the workhorse used to track activity on another object. @@ -83,26 +87,48 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Mock { // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2) func (m *Mock) Return(returnArguments ...interface{}) *Mock { - m.ExpectedCalls = append(m.ExpectedCalls, Call{m.onMethodName, m.onMethodArguments, returnArguments}) + m.ExpectedCalls = append(m.ExpectedCalls, Call{m.onMethodName, m.onMethodArguments, returnArguments, 0}) return m } +// Once indicates that that the mock should only return the value once. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() +func (m* Mock) Once() { + m.ExpectedCalls[len(m.ExpectedCalls) - 1].Repeatability = 1 +} + +// Twice indicates that that the mock should only return the value twice. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() +func (m* Mock) Twice() { + m.ExpectedCalls[len(m.ExpectedCalls) - 1].Repeatability = 2 +} + +// Times indicates that that the mock should only return the indicated number +// of times. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) +func (m* Mock) Times(i int) { + m.ExpectedCalls[len(m.ExpectedCalls) - 1].Repeatability = i +} + /* Recording and responding to activity */ -func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (bool, *Call) { - for _, call := range m.ExpectedCalls { - if call.Method == method { +func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { + for i, call := range m.ExpectedCalls { + if call.Method == method && call.Repeatability > -1 { _, diffCount := call.Arguments.Diff(arguments) if diffCount == 0 { - return true, &call + return i, &call } } } - return false, nil + return -1, nil } func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) { @@ -159,8 +185,8 @@ func (m *Mock) Called(arguments ...interface{}) Arguments { found, call := m.findExpectedCall(functionName, arguments...) - if !found { - + switch { + case found < 0: // we have to fail here - because we don't know what to do // as the return arguments. This is because: // @@ -175,11 +201,16 @@ func (m *Mock) Called(arguments ...interface{}) Arguments { } else { panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo())) } - + case call.Repeatability == 1: + call.Repeatability = -1 + m.ExpectedCalls[found] = *call + case call.Repeatability > 1: + call.Repeatability -= 1 + m.ExpectedCalls[found] = *call } // add the call - m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0)}) + m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0), 0}) return call.ReturnArguments @@ -211,11 +242,15 @@ func (m *Mock) AssertExpectations(t *testing.T) bool { // iterate through each expectation for _, expectedCall := range m.ExpectedCalls { - if !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) { + switch { + case !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments): somethingMissing = true failedExpectations++ t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) - } else { + case expectedCall.Repeatability > 0: + somethingMissing = true + failedExpectations++ + default: t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } } diff --git a/mock/mock_test.go b/mock/mock_test.go index f50b3b1..1ed1266 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -22,7 +22,7 @@ type TestExampleImplementation struct { func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) { args := i.Mock.Called(a, b, c) - return args.Int(0), args.Error(1) + return args.Int(0), errors.New("Whoops") } func (i *TestExampleImplementation) TheExampleMethod2(yesorno bool) { @@ -87,6 +87,79 @@ func Test_Mock_Return(t *testing.T) { assert.Equal(t, 1, call.ReturnArguments[0]) assert.Equal(t, "two", call.ReturnArguments[1]) assert.Equal(t, true, call.ReturnArguments[2]) + assert.Equal(t, 0, call.Repeatability) + + } + +} + +func Test_Mock_Return_Once(t *testing.T) { + + // make a test impl object + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("TheExampleMethod", "A", "B", true).Return(1, "two", true).Once() + + // ensure the call was created + if assert.Equal(t, 1, len(mockedService.Mock.ExpectedCalls)) { + call := mockedService.Mock.ExpectedCalls[0] + + assert.Equal(t, "TheExampleMethod", call.Method) + assert.Equal(t, "A", call.Arguments[0]) + assert.Equal(t, "B", call.Arguments[1]) + assert.Equal(t, true, call.Arguments[2]) + assert.Equal(t, 1, call.ReturnArguments[0]) + assert.Equal(t, "two", call.ReturnArguments[1]) + assert.Equal(t, true, call.ReturnArguments[2]) + assert.Equal(t, 1, call.Repeatability) + + } + +} + +func Test_Mock_Return_Twice(t *testing.T) { + + // make a test impl object + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("TheExampleMethod", "A", "B", true).Return(1, "two", true).Twice() + + // ensure the call was created + if assert.Equal(t, 1, len(mockedService.Mock.ExpectedCalls)) { + call := mockedService.Mock.ExpectedCalls[0] + + assert.Equal(t, "TheExampleMethod", call.Method) + assert.Equal(t, "A", call.Arguments[0]) + assert.Equal(t, "B", call.Arguments[1]) + assert.Equal(t, true, call.Arguments[2]) + assert.Equal(t, 1, call.ReturnArguments[0]) + assert.Equal(t, "two", call.ReturnArguments[1]) + assert.Equal(t, true, call.ReturnArguments[2]) + assert.Equal(t, 2, call.Repeatability) + + } + +} + +func Test_Mock_Return_Times(t *testing.T) { + + // make a test impl object + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("TheExampleMethod", "A", "B", true).Return(1, "two", true).Times(5) + + // ensure the call was created + if assert.Equal(t, 1, len(mockedService.Mock.ExpectedCalls)) { + call := mockedService.Mock.ExpectedCalls[0] + + assert.Equal(t, "TheExampleMethod", call.Method) + assert.Equal(t, "A", call.Arguments[0]) + assert.Equal(t, "B", call.Arguments[1]) + assert.Equal(t, true, call.Arguments[2]) + assert.Equal(t, 1, call.ReturnArguments[0]) + assert.Equal(t, "two", call.ReturnArguments[1]) + assert.Equal(t, true, call.ReturnArguments[2]) + assert.Equal(t, 5, call.Repeatability) } @@ -122,7 +195,40 @@ func Test_Mock_findExpectedCall(t *testing.T) { f, c := m.findExpectedCall("Two", 3) - if assert.True(t, f) { + if assert.Equal(t, 2, f) { + if assert.NotNil(t, c) { + assert.Equal(t, "Two", c.Method) + assert.Equal(t, 3, c.Arguments[0]) + assert.Equal(t, "three", c.ReturnArguments[0]) + } + } + +} + +func Test_Mock_findExpectedCall_For_Unknown_Method(t *testing.T) { + + m := new(Mock) + m.On("One", 1).Return("one") + m.On("Two", 2).Return("two") + m.On("Two", 3).Return("three") + + f, _ := m.findExpectedCall("Two") + + assert.Equal(t, -1, f) + +} + +func Test_Mock_findExpectedCall_Respects_Repeatability(t *testing.T) { + + m := new(Mock) + m.On("One", 1).Return("one") + m.On("Two", 2).Return("two").Once() + m.On("Two", 3).Return("three").Twice() + m.On("Two", 3).Return("three").Times(8) + + f, c := m.findExpectedCall("Two", 3) + + if assert.Equal(t, 2, f) { if assert.NotNil(t, c) { assert.Equal(t, "Two", c.Method) assert.Equal(t, 3, c.Arguments[0]) @@ -161,6 +267,58 @@ func Test_Mock_Called(t *testing.T) { } +func Test_Mock_Called_For_Bounded_Repeatability(t *testing.T) { + + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("Test_Mock_Called_For_Bounded_Repeatability", 1, 2, 3).Return(5, "6", true).Once() + mockedService.Mock.On("Test_Mock_Called_For_Bounded_Repeatability", 1, 2, 3).Return(-1, "hi", false) + + returnArguments1 := mockedService.Mock.Called(1, 2, 3) + returnArguments2 := mockedService.Mock.Called(1, 2, 3) + + if assert.Equal(t, 2, len(mockedService.Mock.Calls)) { + assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Mock.Calls[0].Method) + assert.Equal(t, 1, mockedService.Mock.Calls[0].Arguments[0]) + assert.Equal(t, 2, mockedService.Mock.Calls[0].Arguments[1]) + assert.Equal(t, 3, mockedService.Mock.Calls[0].Arguments[2]) + + assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Mock.Calls[1].Method) + assert.Equal(t, 1, mockedService.Mock.Calls[1].Arguments[0]) + assert.Equal(t, 2, mockedService.Mock.Calls[1].Arguments[1]) + assert.Equal(t, 3, mockedService.Mock.Calls[1].Arguments[2]) + } + + if assert.Equal(t, 3, len(returnArguments1)) { + assert.Equal(t, 5, returnArguments1[0]) + assert.Equal(t, "6", returnArguments1[1]) + assert.Equal(t, true, returnArguments1[2]) + } + + if assert.Equal(t, 3, len(returnArguments2)) { + assert.Equal(t, -1, returnArguments2[0]) + assert.Equal(t, "hi", returnArguments2[1]) + assert.Equal(t, false, returnArguments2[2]) + } + +} + +func Test_Mock_Called_For_SetTime_Expectation(t *testing.T) { + + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("TheExampleMethod", 1, 2, 3).Return(5, "6", true).Times(4) + + mockedService.TheExampleMethod(1, 2, 3) + mockedService.TheExampleMethod(1, 2, 3) + mockedService.TheExampleMethod(1, 2, 3) + mockedService.TheExampleMethod(1, 2, 3) + assert.Panics(t, func() { + mockedService.TheExampleMethod(1, 2, 3) + }) + +} + func Test_Mock_Called_Unexpected(t *testing.T) { var mockedService *TestExampleImplementation = new(TestExampleImplementation) @@ -225,6 +383,27 @@ func Test_Mock_AssertExpectations(t *testing.T) { } +func Test_Mock_AssertExpectations_With_Repeatability(t *testing.T) { + + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("Test_Mock_AssertExpectations_With_Repeatability", 1, 2, 3).Return(5, 6, 7).Twice() + + tt := new(testing.T) + assert.False(t, mockedService.AssertExpectations(tt)) + + // make the call now + mockedService.Mock.Called(1, 2, 3) + + assert.False(t, mockedService.AssertExpectations(tt)) + + mockedService.Mock.Called(1, 2, 3) + + // now assert expectations + assert.True(t, mockedService.AssertExpectations(tt)) + +} + func Test_Mock_AssertNumberOfCalls(t *testing.T) { var mockedService *TestExampleImplementation = new(TestExampleImplementation) @@ -265,6 +444,23 @@ func Test_Mock_AssertCalled_WithArguments(t *testing.T) { } +func Test_Mock_AssertCalled_WithArguments_With_Repeatability(t *testing.T) { + + var mockedService *TestExampleImplementation = new(TestExampleImplementation) + + mockedService.Mock.On("Test_Mock_AssertCalled_WithArguments_With_Repeatability", 1, 2, 3).Return(5, 6, 7).Once() + mockedService.Mock.On("Test_Mock_AssertCalled_WithArguments_With_Repeatability", 2, 3, 4).Return(5, 6, 7).Once() + + mockedService.Mock.Called(1, 2, 3) + mockedService.Mock.Called(2, 3, 4) + + tt := new(testing.T) + assert.True(t, mockedService.AssertCalled(tt, "Test_Mock_AssertCalled_WithArguments_With_Repeatability", 1, 2, 3)) + assert.True(t, mockedService.AssertCalled(tt, "Test_Mock_AssertCalled_WithArguments_With_Repeatability", 2, 3, 4)) + assert.False(t, mockedService.AssertCalled(tt, "Test_Mock_AssertCalled_WithArguments_With_Repeatability", 3, 4, 5)) + +} + func Test_Mock_AssertNotCalled(t *testing.T) { var mockedService *TestExampleImplementation = new(TestExampleImplementation)