mirror of https://github.com/stretchr/testify.git
mock: fix races to m.Calls/m.ExpectedCalls
Fixes races inside the mock package caused by unsychronized reads/writes of m.Calls/m.ExpectedCalls.pull/177/head
parent
3c81d9b268
commit
b11fb16915
71
mock/mock.go
71
mock/mock.go
|
@ -2,13 +2,14 @@ package mock
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/objx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/objx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestingT is an interface wrapper around *testing.T
|
||||
|
@ -123,14 +124,18 @@ func (m *Mock) Return(returnArguments ...interface{}) *Mock {
|
|||
//
|
||||
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
|
||||
func (m *Mock) Once() {
|
||||
m.mutex.Lock()
|
||||
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 1
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Twice indicates that that the mock should only return the value twice.
|
||||
//
|
||||
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
|
||||
func (m *Mock) Twice() {
|
||||
m.mutex.Lock()
|
||||
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = 2
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Times indicates that that the mock should only return the indicated number
|
||||
|
@ -138,7 +143,9 @@ func (m *Mock) Twice() {
|
|||
//
|
||||
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
|
||||
func (m *Mock) Times(i int) {
|
||||
m.mutex.Lock()
|
||||
m.ExpectedCalls[len(m.ExpectedCalls)-1].Repeatability = i
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// WaitUntil sets the channel that will block the mock's return until its closed
|
||||
|
@ -146,7 +153,9 @@ func (m *Mock) Times(i int) {
|
|||
//
|
||||
// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
|
||||
func (m *Mock) WaitUntil(w <-chan time.Time) *Mock {
|
||||
m.mutex.Lock()
|
||||
m.ExpectedCalls[len(m.ExpectedCalls)-1].WaitFor = w
|
||||
m.mutex.Unlock()
|
||||
return m
|
||||
}
|
||||
|
||||
|
@ -166,7 +175,9 @@ func (m *Mock) After(d time.Duration) *Mock {
|
|||
// arg["foo"] = "bar"
|
||||
// })
|
||||
func (m *Mock) Run(fn func(Arguments)) *Mock {
|
||||
m.mutex.Lock()
|
||||
m.ExpectedCalls[len(m.ExpectedCalls)-1].Run = fn
|
||||
m.mutex.Unlock()
|
||||
return m
|
||||
}
|
||||
|
||||
|
@ -175,7 +186,7 @@ func (m *Mock) Run(fn func(Arguments)) *Mock {
|
|||
*/
|
||||
|
||||
func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
|
||||
for i, call := range m.ExpectedCalls {
|
||||
for i, call := range m.expectedCalls() {
|
||||
if call.Method == method && call.Repeatability > -1 {
|
||||
|
||||
_, diffCount := call.Arguments.Diff(arguments)
|
||||
|
@ -189,11 +200,10 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *
|
|||
}
|
||||
|
||||
func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
|
||||
|
||||
diffCount := 0
|
||||
var closestCall *Call = nil
|
||||
|
||||
for _, call := range m.ExpectedCalls {
|
||||
for _, call := range m.expectedCalls() {
|
||||
if call.Method == method {
|
||||
|
||||
_, tempDiffCount := call.Arguments.Diff(arguments)
|
||||
|
@ -231,9 +241,6 @@ func callString(method string, arguments Arguments, includeArgumentValues bool)
|
|||
// appropriate .On .Return() calls)
|
||||
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
|
||||
func (m *Mock) Called(arguments ...interface{}) Arguments {
|
||||
defer m.mutex.Unlock()
|
||||
m.mutex.Lock()
|
||||
|
||||
// get the calling function's name
|
||||
pc, _, _, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
|
@ -245,8 +252,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
|
|||
|
||||
found, call := m.findExpectedCall(functionName, arguments...)
|
||||
|
||||
switch {
|
||||
case found < 0:
|
||||
if found < 0 {
|
||||
// we have to fail here - because we don't know what to do
|
||||
// as the return arguments. This is because:
|
||||
//
|
||||
|
@ -261,16 +267,23 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
|
|||
} else {
|
||||
panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo()))
|
||||
}
|
||||
case call.Repeatability == 1:
|
||||
call.Repeatability = -1
|
||||
m.ExpectedCalls[found] = *call
|
||||
case call.Repeatability > 1:
|
||||
call.Repeatability -= 1
|
||||
m.ExpectedCalls[found] = *call
|
||||
} else {
|
||||
m.mutex.Lock()
|
||||
switch {
|
||||
case call.Repeatability == 1:
|
||||
call.Repeatability = -1
|
||||
m.ExpectedCalls[found] = *call
|
||||
case call.Repeatability > 1:
|
||||
call.Repeatability -= 1
|
||||
m.ExpectedCalls[found] = *call
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// add the call
|
||||
m.mutex.Lock()
|
||||
m.Calls = append(m.Calls, Call{functionName, arguments, make([]interface{}, 0), 0, nil, nil})
|
||||
m.mutex.Unlock()
|
||||
|
||||
// block if specified
|
||||
if call.WaitFor != nil {
|
||||
|
@ -305,12 +318,12 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
|
|||
// AssertExpectations asserts that everything specified with On and Return was
|
||||
// in fact called as expected. Calls may have occurred in any order.
|
||||
func (m *Mock) AssertExpectations(t TestingT) bool {
|
||||
|
||||
var somethingMissing bool = false
|
||||
var failedExpectations int = 0
|
||||
|
||||
// iterate through each expectation
|
||||
for _, expectedCall := range m.ExpectedCalls {
|
||||
expectedCalls := m.expectedCalls()
|
||||
for _, expectedCall := range expectedCalls {
|
||||
switch {
|
||||
case !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments):
|
||||
somethingMissing = true
|
||||
|
@ -325,7 +338,7 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
|
|||
}
|
||||
|
||||
if somethingMissing {
|
||||
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(m.ExpectedCalls)-failedExpectations, len(m.ExpectedCalls), failedExpectations, assert.CallerInfo())
|
||||
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
|
||||
}
|
||||
|
||||
return !somethingMissing
|
||||
|
@ -334,7 +347,7 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
|
|||
// AssertNumberOfCalls asserts that the method was called expectedCalls times.
|
||||
func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
|
||||
var actualCalls int = 0
|
||||
for _, call := range m.Calls {
|
||||
for _, call := range m.calls() {
|
||||
if call.Method == methodName {
|
||||
actualCalls++
|
||||
}
|
||||
|
@ -345,7 +358,7 @@ func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls
|
|||
// AssertCalled asserts that the method was called.
|
||||
func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
|
||||
if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
|
||||
t.Logf("%v", m.ExpectedCalls)
|
||||
t.Logf("%v", m.expectedCalls())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
@ -354,14 +367,14 @@ func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interfac
|
|||
// AssertNotCalled asserts that the method was not called.
|
||||
func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
|
||||
if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
|
||||
t.Logf("%v", m.ExpectedCalls)
|
||||
t.Logf("%v", m.expectedCalls())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
|
||||
for _, call := range m.Calls {
|
||||
for _, call := range m.calls() {
|
||||
if call.Method == methodName {
|
||||
|
||||
_, differences := Arguments(expected).Diff(call.Arguments)
|
||||
|
@ -377,6 +390,18 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (m *Mock) expectedCalls() []Call {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return append([]Call{}, m.ExpectedCalls...)
|
||||
}
|
||||
|
||||
func (m *Mock) calls() []Call {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return append([]Call{}, m.Calls...)
|
||||
}
|
||||
|
||||
/*
|
||||
Arguments
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue