fix: make EventuallyWithT concurrency safe

pull/1484/head
Grzegorz Burzyński 2023-06-02 23:14:55 +02:00 committed by Olivier Mengué
parent 11a6452626
commit 4ed68e1bca
2 changed files with 67 additions and 22 deletions

View File

@ -1873,23 +1873,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
}
// FailNow panics.
func (c *CollectT) FailNow() {
func (*CollectT) FailNow() {
panic("Assertion failed")
}
// Reset clears the collected errors.
func (c *CollectT) Reset() {
c.errors = nil
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (*CollectT) Reset() {
panic("Reset() is deprecated")
}
// Copy copies the collected errors to the supplied t.
func (c *CollectT) Copy(t TestingT) {
if tt, ok := t.(tHelper); ok {
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (*CollectT) Copy(TestingT) {
panic("Copy() is deprecated")
}
// EventuallyWithT asserts that given condition will be met in waitFor time,
@ -1915,8 +1910,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper()
}
collect := new(CollectT)
ch := make(chan bool, 1)
var lastFinishedTickErrs []error
ch := make(chan []error, 1)
timer := time.NewTimer(waitFor)
defer timer.Stop()
@ -1927,19 +1922,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; {
select {
case <-timer.C:
collect.Copy(t)
for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
collect.Reset()
go func() {
collect := new(CollectT)
condition(collect)
ch <- len(collect.errors) == 0
ch <- collect.errors
}()
case v := <-ch:
if v {
case errs := <-ch:
if len(errs) == 0 {
return true
}
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C
}
}

View File

@ -2766,11 +2766,22 @@ func TestEventuallyTrue(t *testing.T) {
True(t, Eventually(t, condition, 100*time.Millisecond, 20*time.Millisecond))
}
// errorsCapturingT is a mock implementation of TestingT that captures errors reported with Errorf.
type errorsCapturingT struct {
errors []error
}
func (t *errorsCapturingT) Errorf(format string, args ...interface{}) {
t.errors = append(t.errors, fmt.Errorf(format, args...))
}
func (t *errorsCapturingT) Helper() {}
func TestEventuallyWithTFalse(t *testing.T) {
mockT := new(CollectT)
mockT := new(errorsCapturingT)
condition := func(collect *CollectT) {
True(collect, false)
Fail(collect, "condition fixed failure")
}
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
@ -2778,7 +2789,7 @@ func TestEventuallyWithTFalse(t *testing.T) {
}
func TestEventuallyWithTTrue(t *testing.T) {
mockT := new(CollectT)
mockT := new(errorsCapturingT)
state := 0
condition := func(collect *CollectT) {
@ -2792,6 +2803,41 @@ func TestEventuallyWithTTrue(t *testing.T) {
Len(t, mockT.errors, 0)
}
func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
mockT := new(errorsCapturingT)
condition := func(collect *CollectT) {
Fail(collect, "condition fixed failure")
}
// To trigger race conditions, we run EventuallyWithT with a nanosecond tick.
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond))
Len(t, mockT.errors, 2)
}
func TestEventuallyWithT_ReturnsTheLatestFinishedConditionErrors(t *testing.T) {
// We'll use a channel to control whether a condition should sleep or not.
mustSleep := make(chan bool, 2)
mustSleep <- false
mustSleep <- true
close(mustSleep)
condition := func(collect *CollectT) {
if <-mustSleep {
// Sleep to ensure that the second condition runs longer than timeout.
time.Sleep(time.Second)
return
}
// The first condition will fail. We expect to get this error as a result.
Fail(collect, "condition fixed failure")
}
mockT := new(errorsCapturingT)
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
Len(t, mockT.errors, 2)
}
func TestNeverFalse(t *testing.T) {
condition := func() bool {
return false