diff --git a/suite/suite.go b/suite/suite.go index a0f0945..ac6744d 100644 --- a/suite/suite.go +++ b/suite/suite.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var matchMethod = flag.String("m", "", "regular expression to select tests of the suite to run") @@ -17,7 +18,8 @@ var matchMethod = flag.String("m", "", "regular expression to select tests of th // retrieving the current *testing.T context. type Suite struct { *assert.Assertions - t *testing.T + require *require.Assertions + t *testing.T } // T retrieves the current *testing.T context. @@ -31,6 +33,26 @@ func (suite *Suite) SetT(t *testing.T) { suite.Assertions = assert.New(t) } +// Require returns a require context for suite. +func (suite *Suite) Require() *require.Assertions { + if suite.require == nil { + suite.require = require.New(suite.T()) + } + return suite.require +} + +// Assert returns an assert context for suite. Normally, you can call +// `suite.NoError(expected, actual)`, but for situations where the embedded +// methods are overridden (for example, you might want to override +// assert.Assertions with require.Assertions), this method is provided so you +// can call `suite.Assert().NoError()`. +func (suite *Suite) Assert() *assert.Assertions { + if suite.Assertions == nil { + suite.Assertions = assert.New(suite.T()) + } + return suite.Assertions +} + // Run takes a testing suite and runs all of the tests attached // to it. func Run(t *testing.T, suite TestingSuite) { diff --git a/suite/suite_test.go b/suite/suite_test.go index 0108123..6a1bb2c 100644 --- a/suite/suite_test.go +++ b/suite/suite_test.go @@ -139,6 +139,15 @@ func TestRunSuite(t *testing.T) { } +func TestSuiteGetters(t *testing.T) { + suite := new(SuiteTester) + suite.SetT(t) + assert.NotNil(t, suite.Assert()) + assert.Equal(t, suite.Assertions, suite.Assert()) + assert.NotNil(t, suite.Require()) + assert.Equal(t, suite.require, suite.Require()) +} + type SuiteLoggingTester struct { Suite }