diff --git a/assert/assertions.go b/assert/assertions.go index a9318cf..e4ddc8c 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -1949,6 +1949,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t } ch := make(chan bool, 1) + checkCond := func() { ch <- condition() } timer := time.NewTimer(waitFor) defer timer.Stop() @@ -1956,18 +1957,23 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. + go checkCond() + + for { select { case <-timer.C: return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tick: - tick = nil - go func() { ch <- condition() }() + case <-tickC: + tickC = nil + go checkCond() case v := <-ch: if v { return true } - tick = ticker.C + tickC = ticker.C } } } @@ -2037,35 +2043,42 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time var lastFinishedTickErrs []error ch := make(chan *CollectT, 1) + checkCond := func() { + collect := new(CollectT) + defer func() { + ch <- collect + }() + condition(collect) + } + timer := time.NewTimer(waitFor) defer timer.Stop() ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. + go checkCond() + + for { select { case <-timer.C: for _, err := range lastFinishedTickErrs { t.Errorf("%v", err) } return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tick: - tick = nil - go func() { - collect := new(CollectT) - defer func() { - ch <- collect - }() - condition(collect) - }() + case <-tickC: + tickC = nil + go checkCond() case collect := <-ch: if !collect.failed() { return true } // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. lastFinishedTickErrs = collect.errors - tick = ticker.C + tickC = ticker.C } } } @@ -2080,6 +2093,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D } ch := make(chan bool, 1) + checkCond := func() { ch <- condition() } timer := time.NewTimer(waitFor) defer timer.Stop() @@ -2087,18 +2101,23 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. + go checkCond() + + for { select { case <-timer.C: return true - case <-tick: - tick = nil - go func() { ch <- condition() }() + case <-tickC: + tickC = nil + go checkCond() case v := <-ch: if v { return Fail(t, "Condition satisfied", msgAndArgs...) } - tick = ticker.C + tickC = ticker.C } } } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 9d27d1b..be7d930 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -3058,6 +3058,49 @@ func TestEventuallyWithTFailNow(t *testing.T) { Len(t, mockT.errors, 1) } +// Check that a long running condition doesn't block Eventually. +// See issue 805 (and its long tail of following issues) +func TestEventuallyTimeout(t *testing.T) { + mockT := new(testing.T) + + NotPanics(t, func() { + done, done2 := make(chan struct{}), make(chan struct{}) + + // A condition function that returns after the Eventually timeout + condition := func() bool { + // Wait until Eventually times out and terminates + <-done + close(done2) + return true + } + + False(t, Eventually(mockT, condition, time.Millisecond, time.Microsecond)) + + close(done) + <-done2 + }) +} + +func TestEventuallySucceedQuickly(t *testing.T) { + mockT := new(testing.T) + + condition := func() bool { return true } + + // By making the tick longer than the total duration, we expect that this test would fail if + // we didn't check the condition before the first tick elapses. + True(t, Eventually(mockT, condition, 100*time.Millisecond, time.Second)) +} + +func TestEventuallyWithTSucceedQuickly(t *testing.T) { + mockT := new(testing.T) + + condition := func(t *CollectT) {} + + // By making the tick longer than the total duration, we expect that this test would fail if + // we didn't check the condition before the first tick elapses. + True(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Second)) +} + func TestNeverFalse(t *testing.T) { condition := func() bool { return false @@ -3085,27 +3128,13 @@ func TestNeverTrue(t *testing.T) { False(t, Never(mockT, condition, 100*time.Millisecond, 20*time.Millisecond)) } -// Check that a long running condition doesn't block Eventually. -// See issue 805 (and its long tail of following issues) -func TestEventuallyTimeout(t *testing.T) { +func TestNeverFailQuickly(t *testing.T) { mockT := new(testing.T) - NotPanics(t, func() { - done, done2 := make(chan struct{}), make(chan struct{}) - - // A condition function that returns after the Eventually timeout - condition := func() bool { - // Wait until Eventually times out and terminates - <-done - close(done2) - return true - } - - False(t, Eventually(mockT, condition, time.Millisecond, time.Microsecond)) - - close(done) - <-done2 - }) + // By making the tick longer than the total duration, we expect that this test would fail if + // we didn't check the condition before the first tick elapses. + condition := func() bool { return true } + False(t, Never(mockT, condition, 100*time.Millisecond, time.Second)) } func Test_validateEqualArgs(t *testing.T) { diff --git a/mock/mock.go b/mock/mock.go index c81a0bd..73e020f 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -948,6 +948,8 @@ func (args Arguments) Is(objects ...interface{}) bool { return true } +type outputRenderer func() string + // Diff gets a string describing the differences between the arguments // and the specified objects. // @@ -955,7 +957,7 @@ func (args Arguments) Is(objects ...interface{}) bool { func (args Arguments) Diff(objects []interface{}) (string, int) { // TODO: could return string as error and nil for No difference - output := "\n" + var outputBuilder strings.Builder var differences int maxArgCount := len(args) @@ -963,24 +965,35 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { maxArgCount = len(objects) } + outputRenderers := []outputRenderer{} + for i := 0; i < maxArgCount; i++ { + i := i var actual, expected interface{} - var actualFmt, expectedFmt string + var actualFmt, expectedFmt func() string if len(objects) <= i { actual = "(Missing)" - actualFmt = "(Missing)" + actualFmt = func() string { + return "(Missing)" + } } else { actual = objects[i] - actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + actualFmt = func() string { + return fmt.Sprintf("(%[1]T=%[1]v)", actual) + } } if len(args) <= i { expected = "(Missing)" - expectedFmt = "(Missing)" + expectedFmt = func() string { + return "(Missing)" + } } else { expected = args[i] - expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + expectedFmt = func() string { + return fmt.Sprintf("(%[1]T=%[1]v)", expected) + } } if matcher, ok := expected.(argumentMatcher); ok { @@ -988,16 +1001,22 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { func() { defer func() { if r := recover(); r != nil { - actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) + actualFmt = func() string { + return fmt.Sprintf("panic in argument matcher: %v", r) + } } }() matches = matcher.Matches(actual) }() if matches { - output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt(), matcher) + }) } else { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt(), matcher) + }) } } else { switch expected := expected.(type) { @@ -1006,13 +1025,17 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt()) + }) } case *IsTypeArgument: actualT := reflect.TypeOf(actual) if actualT != expected.t { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected.t.Name(), actualT.Name(), actualFmt()) + }) } case *FunctionalOptionsArgument: var name string @@ -1023,26 +1046,36 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { const tName = "[]interface{}" if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, tName, reflect.TypeOf(actual).Name(), actualFmt()) + }) } else { if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" { // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, tName, tName) + }) } else { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, af, ef) + }) } } default: if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt(), expectedFmt()) + }) } else { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + outputRenderers = append(outputRenderers, func() string { + return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt(), expectedFmt()) + }) } } } @@ -1053,7 +1086,12 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { return "No differences.", differences } - return output, differences + outputBuilder.WriteString("\n") + for _, r := range outputRenderers { + outputBuilder.WriteString(r()) + } + + return outputBuilder.String(), differences } // Assert compares the arguments with the specified objects and fails if