diff --git a/_codegen/main.go b/_codegen/main.go index 3f08151..9eae052 100644 --- a/_codegen/main.go +++ b/_codegen/main.go @@ -35,61 +35,94 @@ var ( func main() { flag.Parse() - tmplHead, err := template.New("header").Parse(defaultTemplate) + scope, docs, err := parsePackageSource(*pkg) if err != nil { log.Fatal(err) } + + importer, funcs, err := analyzeCode(scope, docs) + if err != nil { + log.Fatal(err) + } + + if err := generateCode(importer, funcs); err != nil { + log.Fatal() + } +} + +func generateCode(importer imports.Importer, funcs []Func) error { + buff := bytes.NewBuffer(nil) + + tmplHead, tmplFunc, err := parseTemplates() + if err != nil { + return err + } + + // Generate header + if err := tmplHead.Execute(buff, struct { + Name string + Imports map[string]string + }{ + *outputPkg, + importer.Imports(), + }); err != nil { + return err + } + + // Generate funcs + for _, fn := range funcs { + buff.Write([]byte("\n\n")) + if err := tmplFunc.Execute(buff, &fn); err != nil { + return err + } + } + + // Write file + output, err := outputFile() + if err != nil { + return err + } + defer output.Close() + _, err = io.Copy(output, buff) + return err +} + +func parseTemplates() (*template.Template, *template.Template, error) { + tmplHead, err := template.New("header").Parse(headerTemplate) + if err != nil { + return nil, nil, err + } if *tmplFile != "" { f, err := ioutil.ReadFile(*tmplFile) if err != nil { - log.Fatal(err) + return nil, nil, err } funcTemplate = string(f) } tmpl, err := template.New("function").Parse(funcTemplate) if err != nil { - log.Fatal(err) + return nil, nil, err } + return tmplHead, tmpl, nil +} - pd, err := build.Import(*pkg, ".", 0) - if err != nil { - log.Fatal(err) +func outputFile() (*os.File, error) { + filename := *out + if filename == "-" || (filename == "" && *tmplFile == "") { + return os.Stdout, nil } + if filename == "" { + filename = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + ".go" + } + return os.Create(filename) +} - fset := token.NewFileSet() - files := make(map[string]*ast.File) - fileList := make([]*ast.File, len(pd.GoFiles)) - for i, fname := range pd.GoFiles { - src, err := ioutil.ReadFile(path.Join(pd.SrcRoot, pd.ImportPath, fname)) - if err != nil { - log.Fatal(err) - } - f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors) - if err != nil { - log.Fatal(err) - } - files[fname] = f - fileList[i] = f - } - - cfg := types.Config{ - Importer: importer.Default(), - } - info := types.Info{ - Defs: make(map[*ast.Ident]types.Object), - } - tp, err := cfg.Check(*pkg, fset, fileList, &info) - if err != nil { - log.Fatal(err) - } - - scope := tp.Scope() +// analyzeCode takes the types scope and the docs and returns the import +// information and information about all the assertion functions. +func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []Func, error) { testingT := scope.Lookup("TestingT").Type().Underlying().(*types.Interface) - ap, _ := ast.NewPackage(fset, files, nil, nil) - docs := doc.New(ap, *pkg, 0) - - imports := imports.New(*outputPkg) + importer := imports.New(*outputPkg) funcs := make([]Func, 0) // Go through all the top level functions for _, fdocs := range docs.Funcs { @@ -119,61 +152,52 @@ func main() { } funcs = append(funcs, Func{*outputPkg, fdocs, fn}) - imports.AddImportsFrom(sig.Params()) + importer.AddImportsFrom(sig.Params()) } - - buff := bytes.NewBuffer(nil) - - if err := tmplHead.Execute(buff, struct { - Name string - Imports map[string]string - }{ - *outputPkg, - imports.Imports(), - }); err != nil { - log.Fatal(err) - } - for _, fn := range funcs { - buff.Write([]byte("\n\n")) - if err := tmpl.Execute(buff, &fn); err != nil { - log.Fatal(err) - } - } - - var output *os.File - if *out == "-" || (*out == "" && *tmplFile == "") { - *out = "-" - output = os.Stdout - } else if *out == "" { - *out = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + ".go" - } - if *out != "-" { - output, err = os.Create(*out) - if err != nil { - log.Fatal(err) - } - defer output.Close() - } - io.Copy(output, buff) + return importer, funcs, nil } -var defaultTemplate = `/* -* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen -* THIS FILE MUST NOT BE EDITED BY HAND -*/ +// parsePackageSource returns the types scope and the package documentation from the pa +func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) { + pd, err := build.Import(pkg, ".", 0) + if err != nil { + return nil, nil, err + } -package {{.Name}} + fset := token.NewFileSet() + files := make(map[string]*ast.File) + fileList := make([]*ast.File, len(pd.GoFiles)) + for i, fname := range pd.GoFiles { + src, err := ioutil.ReadFile(path.Join(pd.SrcRoot, pd.ImportPath, fname)) + if err != nil { + return nil, nil, err + } + f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors) + if err != nil { + return nil, nil, err + } + files[fname] = f + fileList[i] = f + } -import ( -{{range $path, $name := .Imports}} - {{$name}} "{{$path}}"{{end}} -) -` + cfg := types.Config{ + Importer: importer.Default(), + } + info := types.Info{ + Defs: make(map[*ast.Ident]types.Object), + } + tp, err := cfg.Check(pkg, fset, fileList, &info) + if err != nil { + return nil, nil, err + } -var funcTemplate = `{{.Comment}} -func (fwd *AssertionsForwarder) {{.DocInfo.Name}}({{.Params}}) bool { - return assert.{{.DocInfo.Name}}({{.ForwardedParams}}) -}` + scope := tp.Scope() + + ap, _ := ast.NewPackage(fset, files, nil, nil) + docs := doc.New(ap, pkg, 0) + + return scope, docs, nil +} type Func struct { CurrentPkg string @@ -237,3 +261,21 @@ func (f *Func) ForwardedParams() string { func (f *Func) Comment() string { return "// " + strings.Replace(strings.TrimSpace(f.DocInfo.Doc), "\n", "\n// ", -1) } + +var headerTemplate = `/* +* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen +* THIS FILE MUST NOT BE EDITED BY HAND +*/ + +package {{.Name}} + +import ( +{{range $path, $name := .Imports}} + {{$name}} "{{$path}}"{{end}} +) +` + +var funcTemplate = `{{.Comment}} +func (fwd *AssertionsForwarder) {{.DocInfo.Name}}({{.Params}}) bool { + return assert.{{.DocInfo.Name}}({{.ForwardedParams}}) +}`