diff --git a/mock/mock.go b/mock/mock.go index 30fcf3d..a6d9526 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -65,6 +65,11 @@ type Call struct { // reference. It's useful when mocking methods such as unmarshalers or // decoders. RunFn func(Arguments) + + // PanicMsg holds msg to be used to mock panic on the function call + // if the PanicMsg is set to a non nil string the function call will panic + // irrespective of other settings + PanicMsg *string } func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call { @@ -77,6 +82,7 @@ func newCall(parent *Mock, methodName string, callerInfo []string, methodArgumen Repeatability: 0, WaitFor: nil, RunFn: nil, + PanicMsg: nil, } } @@ -100,6 +106,18 @@ func (c *Call) Return(returnArguments ...interface{}) *Call { return c } +// Panic specifies if the functon call should fail and the panic message +// +// Mock.On("DoSomething").Panic("test panic") +func (c *Call) Panic(msg string) *Call { + c.lock() + defer c.unlock() + + c.PanicMsg = &msg + + return c +} + // Once indicates that that the mock should only return the value once. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() @@ -392,6 +410,13 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen time.Sleep(call.waitTime) } + m.mutex.Lock() + panicMsg := call.PanicMsg + m.mutex.Unlock() + if panicMsg != nil { + panic(*panicMsg) + } + m.mutex.Lock() runFn := call.RunFn m.mutex.Unlock() diff --git a/mock/mock_test.go b/mock/mock_test.go index fcbc4f6..1b681aa 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -486,6 +486,29 @@ func Test_Mock_Return(t *testing.T) { assert.Nil(t, call.WaitFor) } +func Test_Mock_Panic(t *testing.T) { + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService. + On("TheExampleMethod", "A", "B", true). + Panic("panic message for example method") + + require.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + + call := mockedService.ExpectedCalls[0] + + assert.Equal(t, "TheExampleMethod", call.Method) + assert.Equal(t, "A", call.Arguments[0]) + assert.Equal(t, "B", call.Arguments[1]) + assert.Equal(t, true, call.Arguments[2]) + assert.Equal(t, 0, call.Repeatability) + assert.Equal(t, 0, call.Repeatability) + assert.Equal(t, "panic message for example method", *call.PanicMsg) + assert.Nil(t, call.WaitFor) +} + func Test_Mock_Return_WaitUntil(t *testing.T) { // make a test impl object @@ -1420,6 +1443,14 @@ func Test_MockMethodCalled(t *testing.T) { m.AssertExpectations(t) } +func Test_MockMethodCalled_Panic(t *testing.T) { + m := new(Mock) + m.On("foo", "hello").Panic("world panics") + + require.PanicsWithValue(t, "world panics", func() { m.MethodCalled("foo", "hello") }) + m.AssertExpectations(t) +} + // Test to validate fix for racy concurrent call access in MethodCalled() func Test_MockReturnAndCalledConcurrent(t *testing.T) { iterations := 1000