Merge pull request #447 from stretchr/issue-442

Fix race condition on mock package's Called
pull/356/head
Ernesto Jiménez 2017-05-28 14:51:04 +01:00 committed by GitHub
commit b8c9b4ef3d
1 changed files with 21 additions and 22 deletions

View File

@ -214,8 +214,6 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
// */
func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
m.mutex.Lock()
defer m.mutex.Unlock()
for i, call := range m.ExpectedCalls {
if call.Method == method && call.Repeatability > -1 {
@ -288,6 +286,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
parts := strings.Split(functionPath, ".")
functionName := parts[len(parts)-1]
m.mutex.Lock()
found, call := m.findExpectedCall(functionName, arguments...)
if found < 0 {
@ -299,31 +298,29 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
// c) the developer has forgotten to add an accompanying On...Return pair.
closestFound, closestCall := m.findClosestCall(functionName, arguments...)
m.mutex.Unlock()
if closestFound {
panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\n", callString(functionName, arguments, true), callString(functionName, closestCall.Arguments, true), diffArguments(arguments, closestCall.Arguments)))
} else {
panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo()))
}
} else {
m.mutex.Lock()
switch {
case call.Repeatability == 1:
call.Repeatability = -1
call.totalCalls++
}
case call.Repeatability > 1:
call.Repeatability--
call.totalCalls++
switch {
case call.Repeatability == 1:
call.Repeatability = -1
call.totalCalls++
case call.Repeatability == 0:
call.totalCalls++
}
m.mutex.Unlock()
case call.Repeatability > 1:
call.Repeatability--
call.totalCalls++
case call.Repeatability == 0:
call.totalCalls++
}
// add the call
m.mutex.Lock()
m.Calls = append(m.Calls, *newCall(m, functionName, arguments...))
m.mutex.Unlock()
@ -368,6 +365,8 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
// AssertExpectations asserts that everything specified with On and Return was
// in fact called as expected. Calls may have occurred in any order.
func (m *Mock) AssertExpectations(t TestingT) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
var somethingMissing bool
var failedExpectations int
@ -379,14 +378,12 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
failedExpectations++
t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
} else {
m.mutex.Lock()
if expectedCall.Repeatability > 0 {
somethingMissing = true
failedExpectations++
} else {
t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
}
m.mutex.Unlock()
}
}
@ -399,6 +396,8 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
// AssertNumberOfCalls asserts that the method was called expectedCalls times.
func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
var actualCalls int
for _, call := range m.calls() {
if call.Method == methodName {
@ -411,6 +410,8 @@ func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls
// AssertCalled asserts that the method was called.
// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
t.Logf("%v", m.expectedCalls())
return false
@ -421,6 +422,8 @@ func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interfac
// AssertNotCalled asserts that the method was not called.
// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
t.Logf("%v", m.expectedCalls())
return false
@ -446,14 +449,10 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
}
func (m *Mock) expectedCalls() []*Call {
m.mutex.Lock()
defer m.mutex.Unlock()
return append([]*Call{}, m.ExpectedCalls...)
}
func (m *Mock) calls() []Call {
m.mutex.Lock()
defer m.mutex.Unlock()
return append([]Call{}, m.Calls...)
}