diff --git a/assert/assertions.go b/assert/assertions.go index 818cd7b..646d23d 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -10,6 +10,8 @@ import ( "runtime" "strings" "time" + "unicode" + "unicode/utf8" ) // TestingT is an interface wrapper around *testing.T @@ -64,28 +66,62 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { internally, causing it to print the file:line of the assert method, rather than where the problem actually occured in calling code.*/ -// CallerInfo returns a string containing the file and line number of the assert call -// that failed. -func CallerInfo() string { +// CallerInfo returns an array of strings containing the file and line number +// of each stack frame leading from the current test to the assert call that +// failed. +func CallerInfo() []string { + pc := uintptr(0) file := "" line := 0 ok := false + name := "" + callers := []string{} for i := 0; ; i++ { - _, file, line, ok = runtime.Caller(i) + pc, file, line, ok = runtime.Caller(i) if !ok { - return "" + return nil } + parts := strings.Split(file, "/") dir := parts[len(parts)-2] file = parts[len(parts)-1] if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { + callers = append([]string{fmt.Sprintf("%s:%d", file, line)}, callers...) + } + + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + // Drop the package + segments := strings.Split(name, ".") + name = segments[len(segments)-1] + if isTest(name, "Test") || + isTest(name, "Benchmark") || + isTest(name, "Example") { break } } - return fmt.Sprintf("%s:%d", file, line) + return callers +} + +// Stolen from the `go test` tool. +// isTest tells whether name looks like a test (or benchmark, according to prefix). +// It is a Test (say) if there is a character after Test that is not a lower-case letter. +// We don't want TesticularCancer. +func isTest(name, prefix string) bool { + if !strings.HasPrefix(name, prefix) { + return false + } + if len(name) == len(prefix) { // "Test" is ok + return true + } + rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) + return !unicode.IsLower(rune) } // getWhitespaceString returns a string that is long enough to overwrite the default @@ -144,19 +180,20 @@ func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { message := messageFromMsgAndArgs(msgAndArgs...) + locationInfo := strings.Join(CallerInfo(), "\n\r\t\t\t") if len(message) > 0 { t.Errorf("\r%s\r\tLocation:\t%s\n"+ "\r\tError:%s\n"+ "\r\tMessages:\t%s\n\r", getWhitespaceString(), - CallerInfo(), + locationInfo, indentMessageLines(failureMessage, 2), message) } else { t.Errorf("\r%s\r\tLocation:\t%s\n"+ "\r\tError:%s\n\r", getWhitespaceString(), - CallerInfo(), + locationInfo, indentMessageLines(failureMessage, 2)) }