Compare commits

...

2 Commits

Author SHA1 Message Date
Bracken
2a57335dc9
Merge pull request #1788 from brackendawson/1785-backport-1.11
Backport #1786 to release/1.11: mock: revert to pre-v1.11.0 argument matching behavior for mutating stringers
2025-08-27 12:46:31 +02:00
Bracken
af8c91234f
Backport #1786 to release/1.11
mock: revert to pre-v1.11.0 argument matching behavior for mutating stringers
2025-08-27 12:38:08 +02:00
2 changed files with 37 additions and 55 deletions

View File

@ -948,8 +948,6 @@ func (args Arguments) Is(objects ...interface{}) bool {
return true return true
} }
type outputRenderer func() string
// Diff gets a string describing the differences between the arguments // Diff gets a string describing the differences between the arguments
// and the specified objects. // and the specified objects.
// //
@ -957,7 +955,7 @@ type outputRenderer func() string
func (args Arguments) Diff(objects []interface{}) (string, int) { func (args Arguments) Diff(objects []interface{}) (string, int) {
// TODO: could return string as error and nil for No difference // TODO: could return string as error and nil for No difference
var outputBuilder strings.Builder output := "\n"
var differences int var differences int
maxArgCount := len(args) maxArgCount := len(args)
@ -965,35 +963,24 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
maxArgCount = len(objects) maxArgCount = len(objects)
} }
outputRenderers := []outputRenderer{}
for i := 0; i < maxArgCount; i++ { for i := 0; i < maxArgCount; i++ {
i := i
var actual, expected interface{} var actual, expected interface{}
var actualFmt, expectedFmt func() string var actualFmt, expectedFmt string
if len(objects) <= i { if len(objects) <= i {
actual = "(Missing)" actual = "(Missing)"
actualFmt = func() string { actualFmt = "(Missing)"
return "(Missing)"
}
} else { } else {
actual = objects[i] actual = objects[i]
actualFmt = func() string { actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
return fmt.Sprintf("(%[1]T=%[1]v)", actual)
}
} }
if len(args) <= i { if len(args) <= i {
expected = "(Missing)" expected = "(Missing)"
expectedFmt = func() string { expectedFmt = "(Missing)"
return "(Missing)"
}
} else { } else {
expected = args[i] expected = args[i]
expectedFmt = func() string { expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
return fmt.Sprintf("(%[1]T=%[1]v)", expected)
}
} }
if matcher, ok := expected.(argumentMatcher); ok { if matcher, ok := expected.(argumentMatcher); ok {
@ -1001,22 +988,16 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
func() { func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
actualFmt = func() string { actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
return fmt.Sprintf("panic in argument matcher: %v", r)
}
} }
}() }()
matches = matcher.Matches(actual) matches = matcher.Matches(actual)
}() }()
if matches { if matches {
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
return fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt(), matcher)
})
} else { } else {
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
return fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt(), matcher)
})
} }
} else { } else {
switch expected := expected.(type) { switch expected := expected.(type) {
@ -1025,17 +1006,13 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
// not match // not match
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt())
})
} }
case *IsTypeArgument: case *IsTypeArgument:
actualT := reflect.TypeOf(actual) actualT := reflect.TypeOf(actual)
if actualT != expected.t { if actualT != expected.t {
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected.t.Name(), actualT.Name(), actualFmt())
})
} }
case *FunctionalOptionsArgument: case *FunctionalOptionsArgument:
var name string var name string
@ -1046,36 +1023,26 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
const tName = "[]interface{}" const tName = "[]interface{}"
if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 { if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 {
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, tName, reflect.TypeOf(actual).Name(), actualFmt())
})
} else { } else {
if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" { if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" {
// match // match
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, tName, tName)
})
} else { } else {
// not match // not match
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, af, ef)
})
} }
} }
default: default:
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match // match
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt(), expectedFmt())
})
} else { } else {
// not match // not match
differences++ differences++
outputRenderers = append(outputRenderers, func() string { output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt(), expectedFmt())
})
} }
} }
} }
@ -1086,12 +1053,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
return "No differences.", differences return "No differences.", differences
} }
outputBuilder.WriteString("\n") return output, differences
for _, r := range outputRenderers {
outputBuilder.WriteString(r())
}
return outputBuilder.String(), differences
} }
// Assert compares the arguments with the specified objects and fails if // Assert compares the arguments with the specified objects and fails if

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"runtime" "runtime"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -2421,3 +2422,22 @@ type user interface {
type mockUser struct{ Mock } type mockUser struct{ Mock }
func (m *mockUser) Use(c caller) { m.Called(c) } func (m *mockUser) Use(c caller) { m.Called(c) }
type mutatingStringer struct {
N int
s string
}
func (m *mutatingStringer) String() string {
m.s = strconv.Itoa(m.N)
return m.s
}
func TestIssue1785ArgumentWithMutatingStringer(t *testing.T) {
m := &Mock{}
m.On("Method", &mutatingStringer{N: 2})
m.On("Method", &mutatingStringer{N: 1})
m.MethodCalled("Method", &mutatingStringer{N: 1})
m.MethodCalled("Method", &mutatingStringer{N: 2})
m.AssertExpectations(t)
}