From b93dd46367ed93b4a22882653381e1d6932d3446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Grzegorz=20Burzy=C5=84ski?= Date: Fri, 2 Jun 2023 23:14:55 +0200 Subject: [PATCH] fix: make assert.CollectT concurrency safe --- assert/assertions.go | 20 +++++++++++++------- assert/assertions_test.go | 12 ++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/assert/assertions.go b/assert/assertions.go index 6ab0ec347..1793caeba 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -13,6 +13,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync" "time" "unicode" "unicode/utf8" @@ -1862,6 +1863,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t // CollectT implements the TestingT interface and collects all errors. type CollectT struct { errors []error + mu sync.RWMutex } // Errorf collects the error. @@ -1912,8 +1914,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time h.Helper() } - collect := new(CollectT) - ch := make(chan bool, 1) + var lastTickErrs []error + ch := make(chan []error, 1) timer := time.NewTimer(waitFor) defer timer.Stop() @@ -1924,19 +1926,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time for tick := ticker.C; ; { select { case <-timer.C: - collect.Copy(t) + for _, err := range lastTickErrs { + t.Errorf("%v", err) + } return Fail(t, "Condition never satisfied", msgAndArgs...) case <-tick: tick = nil - collect.Reset() go func() { + collect := new(CollectT) condition(collect) - ch <- len(collect.errors) == 0 + ch <- collect.errors }() - case v := <-ch: - if v { + case errs := <-ch: + if len(errs) == 0 { return true } + // Keep the last tick's errors, so that they can be copied to t if the condition is not met on time. + lastTickErrs = errs tick = ticker.C } } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 162c71801..118983673 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -2786,6 +2786,18 @@ func TestEventuallyWithTTrue(t *testing.T) { Len(t, mockT.errors, 0) } +func TestEventuallyWithT_ConcurrencySafe(t *testing.T) { + mockT := new(CollectT) + + condition := func(collect *CollectT) { + True(collect, false) + } + + // To trigger race conditions, we run EventuallyWithT with a nanosecond tick. + False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond)) + Len(t, mockT.errors, 2) +} + func TestNeverFalse(t *testing.T) { condition := func() bool { return false