mock: fix races to m.Calls/m.ExpectedCalls

Fixes races inside the mock package caused by
unsychronized reads/writes of m.Calls/m.ExpectedCalls.
pull/177/head
Sergiusz Urbaniak 2015-06-05 14:46:11 +02:00
parent 3c81d9b268
commit b11fb16915
1 changed files with 48 additions and 23 deletions

View File

@ -2,13 +2,14 @@ package mock
import (
"fmt"
"github.com/stretchr/objx"
"github.com/stretchr/testify/assert"
"reflect"
"runtime"
"strings"
"sync"
"time"
"github.com/stretchr/objx"
"github.com/stretchr/testify/assert"
)
// TestingT is an interface wrapper around *testing.T
@ -123,14 +124,18 @@ func (m *Mock) Return(returnArguments ...interface{}) *Mock {
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
func (m *Mock) Once() {
m.mutex.Lock()
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 1
m.mutex.Unlock()
}
// Twice indicates that that the mock should only return the value twice.
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
func (m *Mock) Twice() {
m.mutex.Lock()
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 2
m.mutex.Unlock()
}
// Times indicates that that the mock should only return the indicated number
@ -138,7 +143,9 @@ func (m *Mock) Twice() {
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
func (m *Mock) Times(i int) {
m.mutex.Lock()
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = i
m.mutex.Unlock()
}
// WaitUntil sets the channel that will block the mock's return until its closed
@ -146,7 +153,9 @@ func (m *Mock) Times(i int) {
//
// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
func (m *Mock) WaitUntil(w <-chan time.Time) *Mock {
m.mutex.Lock()
m.ExpectedCalls[len(m.ExpectedCalls)-1].WaitFor = w
m.mutex.Unlock()
return m
}
@ -166,7 +175,9 @@ func (m *Mock) After(d time.Duration) *Mock {
// arg["foo"] = "bar"
// })
func (m *Mock) Run(fn func(Arguments)) *Mock {
m.mutex.Lock()
m.ExpectedCalls[len(m.ExpectedCalls)-1].Run = fn
m.mutex.Unlock()
return m
}
@ -175,7 +186,7 @@ func (m *Mock) Run(fn func(Arguments)) *Mock {
*/
func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
for i, call := range m.ExpectedCalls {
for i, call := range m.expectedCalls() {
if call.Method == method && call.Repeatability > -1 {
_, diffCount := call.Arguments.Diff(arguments)
@ -189,11 +200,10 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *
}
func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
diffCount := 0
var closestCall *Call = nil
for _, call := range m.ExpectedCalls {
for _, call := range m.expectedCalls() {
if call.Method == method {
_, tempDiffCount := call.Arguments.Diff(arguments)
@ -231,9 +241,6 @@ func callString(method string, arguments Arguments, includeArgumentValues bool)
// appropriate .On .Return() calls)
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
func (m *Mock) Called(arguments ...interface{}) Arguments {
defer m.mutex.Unlock()
m.mutex.Lock()
// get the calling function's name
pc, _, _, ok := runtime.Caller(1)
if !ok {
@ -245,8 +252,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
found, call := m.findExpectedCall(functionName, arguments...)
switch {
case found < 0:
if found < 0 {
// we have to fail here - because we don't know what to do
// as the return arguments. This is because:
//
@ -261,16 +267,23 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
} else {
panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo()))
}
case call.Repeatability == 1:
call.Repeatability = -1
m.ExpectedCalls[found] = *call
case call.Repeatability > 1:
call.Repeatability -= 1
m.ExpectedCalls[found] = *call
} else {
m.mutex.Lock()
switch {
case call.Repeatability == 1:
call.Repeatability = -1
m.ExpectedCalls[found] = *call
case call.Repeatability > 1:
call.Repeatability -= 1
m.ExpectedCalls[found] = *call
}
m.mutex.Unlock()
}
// add the call
m.mutex.Lock()
m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0), 0, nil, nil})
m.mutex.Unlock()
// block if specified
if call.WaitFor != nil {
@ -305,12 +318,12 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
// AssertExpectations asserts that everything specified with On and Return was
// in fact called as expected. Calls may have occurred in any order.
func (m *Mock) AssertExpectations(t TestingT) bool {
var somethingMissing bool = false
var failedExpectations int = 0
// iterate through each expectation
for _, expectedCall := range m.ExpectedCalls {
expectedCalls := m.expectedCalls()
for _, expectedCall := range expectedCalls {
switch {
case !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments):
somethingMissing = true
@ -325,7 +338,7 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
}
if somethingMissing {
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(m.ExpectedCalls)-failedExpectations, len(m.ExpectedCalls), failedExpectations, assert.CallerInfo())
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
}
return !somethingMissing
@ -334,7 +347,7 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
// AssertNumberOfCalls asserts that the method was called expectedCalls times.
func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
var actualCalls int = 0
for _, call := range m.Calls {
for _, call := range m.calls() {
if call.Method == methodName {
actualCalls++
}
@ -345,7 +358,7 @@ func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls
// AssertCalled asserts that the method was called.
func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
t.Logf("%v", m.ExpectedCalls)
t.Logf("%v", m.expectedCalls())
return false
}
return true
@ -354,14 +367,14 @@ func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interfac
// AssertNotCalled asserts that the method was not called.
func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
t.Logf("%v", m.ExpectedCalls)
t.Logf("%v", m.expectedCalls())
return false
}
return true
}
func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
for _, call := range m.Calls {
for _, call := range m.calls() {
if call.Method == methodName {
_, differences := Arguments(expected).Diff(call.Arguments)
@ -377,6 +390,18 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
return false
}
func (m *Mock) expectedCalls() []Call {
m.mutex.Lock()
defer m.mutex.Unlock()
return append([]Call{}, m.ExpectedCalls...)
}
func (m *Mock) calls() []Call {
m.mutex.Lock()
defer m.mutex.Unlock()
return append([]Call{}, m.Calls...)
}
/*
Arguments
*/