From 7767496ce0b819e144508770e2f9369887a8797f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Grzegorz=20Burzy=C5=84ski?= Date: Fri, 2 Jun 2023 23:14:55 +0200 Subject: [PATCH] fix: make assert.CollectT concurrency safe --- assert/assertions.go | 21 ++++++++++++++++++--- assert/assertions_test.go | 7 +++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/assert/assertions.go b/assert/assertions.go index a55d1bba9..434dbc0f5 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -13,6 +13,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync" "time" "unicode" "unicode/utf8" @@ -1862,10 +1863,13 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t // CollectT implements the TestingT interface and collects all errors. type CollectT struct { errors []error + mu sync.RWMutex } // Errorf collects the error. func (c *CollectT) Errorf(format string, args ...interface{}) { + c.mu.Lock() + defer c.mu.Unlock() c.errors = append(c.errors, fmt.Errorf(format, args...)) } @@ -1876,6 +1880,8 @@ func (c *CollectT) FailNow() { // Reset clears the collected errors. func (c *CollectT) Reset() { + c.mu.Lock() + defer c.mu.Unlock() c.errors = nil } @@ -1884,11 +1890,20 @@ func (c *CollectT) Copy(t TestingT) { if tt, ok := t.(tHelper); ok { tt.Helper() } + c.mu.RLock() + defer c.mu.RUnlock() for _, err := range c.errors { t.Errorf("%v", err) } } +// HasErrors returns true if any errors were collected. +func (c *CollectT) HasErrors() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.errors) > 0 +} + // EventuallyWithT asserts that given condition will be met in waitFor time, // periodically checking target function each tick. In contrast to Eventually, // it supplies a CollectT to the condition function, so that the condition @@ -1931,10 +1946,10 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time collect.Reset() go func() { condition(collect) - ch <- len(collect.errors) == 0 + ch <- collect.HasErrors() }() - case v := <-ch: - if v { + case hasErrors := <-ch: + if !hasErrors { return true } tick = ticker.C diff --git a/assert/assertions_test.go b/assert/assertions_test.go index acd4e59ab..37759fd4d 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -2786,6 +2786,13 @@ func TestEventuallyWithTTrue(t *testing.T) { Len(t, mockT.errors, 0) } +func TestEventuallyWithT_ConcurrencySafe(t *testing.T) { + mockT := new(CollectT) + EventuallyWithT(mockT, func(c *CollectT) { + NoError(c, AnError) + }, time.Millisecond, time.Nanosecond) +} + func TestNeverFalse(t *testing.T) { condition := func() bool { return false