From bd79c01e742b23500fc1d7c40880506473831597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Jime=CC=81nez?= Date: Sun, 28 May 2017 14:43:48 +0100 Subject: [PATCH] Fix race condition on mock package's Called Closes #442 --- mock/mock.go | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index f04914c..78fe1d4 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -214,8 +214,6 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call { // */ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { - m.mutex.Lock() - defer m.mutex.Unlock() for i, call := range m.ExpectedCalls { if call.Method == method && call.Repeatability > -1 { @@ -288,6 +286,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments { parts := strings.Split(functionPath, ".") functionName := parts[len(parts)-1] + m.mutex.Lock() found, call := m.findExpectedCall(functionName, arguments...) if found < 0 { @@ -299,31 +298,29 @@ func (m *Mock) Called(arguments ...interface{}) Arguments { // c) the developer has forgotten to add an accompanying On...Return pair. closestFound, closestCall := m.findClosestCall(functionName, arguments...) + m.mutex.Unlock() if closestFound { panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\n", callString(functionName, arguments, true), callString(functionName, closestCall.Arguments, true), diffArguments(arguments, closestCall.Arguments))) } else { panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo())) } - } else { - m.mutex.Lock() - switch { - case call.Repeatability == 1: - call.Repeatability = -1 - call.totalCalls++ + } - case call.Repeatability > 1: - call.Repeatability-- - call.totalCalls++ + switch { + case call.Repeatability == 1: + call.Repeatability = -1 + call.totalCalls++ - case call.Repeatability == 0: - call.totalCalls++ - } - m.mutex.Unlock() + case call.Repeatability > 1: + call.Repeatability-- + call.totalCalls++ + + case call.Repeatability == 0: + call.totalCalls++ } // add the call - m.mutex.Lock() m.Calls = append(m.Calls, *newCall(m, functionName, arguments...)) m.mutex.Unlock() @@ -368,6 +365,8 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { // AssertExpectations asserts that everything specified with On and Return was // in fact called as expected. Calls may have occurred in any order. func (m *Mock) AssertExpectations(t TestingT) bool { + m.mutex.Lock() + defer m.mutex.Unlock() var somethingMissing bool var failedExpectations int @@ -379,14 +378,12 @@ func (m *Mock) AssertExpectations(t TestingT) bool { failedExpectations++ t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } else { - m.mutex.Lock() if expectedCall.Repeatability > 0 { somethingMissing = true failedExpectations++ } else { t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } - m.mutex.Unlock() } } @@ -399,6 +396,8 @@ func (m *Mock) AssertExpectations(t TestingT) bool { // AssertNumberOfCalls asserts that the method was called expectedCalls times. func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { + m.mutex.Lock() + defer m.mutex.Unlock() var actualCalls int for _, call := range m.calls() { if call.Method == methodName { @@ -411,6 +410,8 @@ func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls // AssertCalled asserts that the method was called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { + m.mutex.Lock() + defer m.mutex.Unlock() if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false @@ -421,6 +422,8 @@ func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interfac // AssertNotCalled asserts that the method was not called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { + m.mutex.Lock() + defer m.mutex.Unlock() if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false @@ -446,14 +449,10 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { } func (m *Mock) expectedCalls() []*Call { - m.mutex.Lock() - defer m.mutex.Unlock() return append([]*Call{}, m.ExpectedCalls...) } func (m *Mock) calls() []Call { - m.mutex.Lock() - defer m.mutex.Unlock() return append([]Call{}, m.Calls...) }