fix data race in the suit

Signed-off-by: Weizhen Wang <wangweizhen@pingcap.com>
pull/1186/head
Weizhen Wang 2022-03-03 23:43:17 +08:00 committed by Boyan Soubachov
parent 35864782d2
commit a409ccf19e
1 changed files with 10 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"runtime/debug" "runtime/debug"
"sync"
"testing" "testing"
"time" "time"
@ -21,17 +22,22 @@ var matchMethod = flag.String("testify.m", "", "regular expression to select tes
// retrieving the current *testing.T context. // retrieving the current *testing.T context.
type Suite struct { type Suite struct {
*assert.Assertions *assert.Assertions
mu sync.Mutex
require *require.Assertions require *require.Assertions
t *testing.T t *testing.T
} }
// T retrieves the current *testing.T context. // T retrieves the current *testing.T context.
func (suite *Suite) T() *testing.T { func (suite *Suite) T() *testing.T {
suite.mu.Lock()
defer suite.mu.Unlock()
return suite.t return suite.t
} }
// SetT sets the current *testing.T context. // SetT sets the current *testing.T context.
func (suite *Suite) SetT(t *testing.T) { func (suite *Suite) SetT(t *testing.T) {
suite.mu.Lock()
defer suite.mu.Unlock()
suite.t = t suite.t = t
suite.Assertions = assert.New(t) suite.Assertions = assert.New(t)
suite.require = require.New(t) suite.require = require.New(t)
@ -39,6 +45,8 @@ func (suite *Suite) SetT(t *testing.T) {
// Require returns a require context for suite. // Require returns a require context for suite.
func (suite *Suite) Require() *require.Assertions { func (suite *Suite) Require() *require.Assertions {
suite.mu.Lock()
defer suite.mu.Unlock()
if suite.require == nil { if suite.require == nil {
suite.require = require.New(suite.T()) suite.require = require.New(suite.T())
} }
@ -51,6 +59,8 @@ func (suite *Suite) Require() *require.Assertions {
// assert.Assertions with require.Assertions), this method is provided so you // assert.Assertions with require.Assertions), this method is provided so you
// can call `suite.Assert().NoError()`. // can call `suite.Assert().NoError()`.
func (suite *Suite) Assert() *assert.Assertions { func (suite *Suite) Assert() *assert.Assertions {
suite.mu.Lock()
defer suite.mu.Unlock()
if suite.Assertions == nil { if suite.Assertions == nil {
suite.Assertions = assert.New(suite.T()) suite.Assertions = assert.New(suite.T())
} }