From 2aa2c176b9dab406a6970f6a55f513e8a8c8b18f Mon Sep 17 00:00:00 2001 From: Dan Heller Date: Mon, 14 Aug 2017 13:04:35 -0700 Subject: [PATCH] Fix unprotected call fields access in MethodCalled() This change fixes a race condition I discovered when a multithreaded test in a service I work on failed under -race. The included test case simulates that failure (concurrent mutation of a Call with invocations on the mock). The test will fail with a data race if run under the race detector; the new locking ensures that call fields are not accessed without the protection of the parent mutex. --- mock/mock.go | 14 +++++++++++--- mock/mock_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index fc63571..5c17124 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -336,11 +336,19 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen <-call.WaitFor } - if call.RunFn != nil { - call.RunFn(arguments) + m.mutex.Lock() + runFn := call.RunFn + m.mutex.Unlock() + + if runFn != nil { + runFn(arguments) } - return call.ReturnArguments + m.mutex.Lock() + returnArgs := call.ReturnArguments + m.mutex.Unlock() + + return returnArgs } /* diff --git a/mock/mock_test.go b/mock/mock_test.go index c050236..b4501f9 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -2,6 +2,7 @@ package mock import ( "errors" + "sync" "testing" "time" @@ -1195,3 +1196,31 @@ func Test_MockMethodCalled(t *testing.T) { require.Equal(t, "world", retArgs[0]) m.AssertExpectations(t) } + +// Test to validate fix for racy concurrent call access in MethodCalled() +func Test_MockReturnAndCalledConcurrent(t *testing.T) { + iterations := 1000 + m := &Mock{} + call := m.On("ConcurrencyTestMethod") + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + for i := 0; i < iterations; i++ { + call.Return(10) + } + wg.Done() + }() + go func() { + for i := 0; i < iterations; i++ { + ConcurrencyTestMethod(m) + } + wg.Done() + }() + wg.Wait() +} + +func ConcurrencyTestMethod(m *Mock) { + m.Called() +}