diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2b76f11..dac66c5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,7 +21,7 @@ jobs: uses: golangci/golangci-lint-action@v3 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: v1.50.1 + version: latest # Optional: working directory, useful for monorepos # working-directory: somedir diff --git a/Makefile b/Makefile index f7227ac..a41d259 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ lint: tools .PHONY: tools tools: - @go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.50.1 + @go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest test-packages: go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) @@ -49,4 +49,4 @@ docker-start-postgres: -e POSTGRES_DB=${GOOSE_POSTGRES_DBNAME} \ -p ${GOOSE_POSTGRES_PORT}:5432 \ -l goose_test \ - postgres:14-alpine + postgres:14-alpine -c log_statement=all diff --git a/cmd/goose/main.go b/cmd/goose/main.go index bb03e36..740ff81 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -7,12 +7,18 @@ import ( "io/fs" "log" "os" + "path/filepath" "runtime/debug" + "sort" "strconv" + "strings" + "text/tabwriter" "text/template" "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/cfg" + "github.com/pressly/goose/v3/internal/migrationstats" + "github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos" ) var ( @@ -95,6 +101,11 @@ func main() { fmt.Printf("%s=%q\n", env.Name, env.Value) } return + case "validate": + if err := printValidate(*dir, *verbose); err != nil { + log.Fatalf("goose validate: %v", err) + } + return } args = mergeArgs(args) @@ -278,3 +289,59 @@ func gooseInit(dir string) error { } return goose.CreateWithTemplate(nil, dir, sqlMigrationTemplate, "initial", "sql") } + +func gatherFilenames(filename string) ([]string, error) { + stat, err := os.Stat(filename) + if err != nil { + return nil, err + } + var filenames []string + if stat.IsDir() { + for _, pattern := range []string{"*.sql", "*.go"} { + file, err := filepath.Glob(filepath.Join(filename, pattern)) + if err != nil { + return nil, err + } + filenames = append(filenames, file...) + } + } else { + filenames = append(filenames, filename) + } + sort.Strings(filenames) + return filenames, nil +} + +func printValidate(filename string, verbose bool) error { + filenames, err := gatherFilenames(filename) + if err != nil { + return err + } + fileWalker := migrationstatsos.NewFileWalker(filenames...) + stats, err := migrationstats.GatherStats(fileWalker, false) + if err != nil { + return err + } + // TODO(mf): we should introduce a --debug flag, which allows printing + // more internal debug information and leave verbose for additional information. + if !verbose { + return nil + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', tabwriter.TabIndent) + fmtPattern := "%v\t%v\t%v\t%v\t%v\t\n" + fmt.Fprintf(w, fmtPattern, "Type", "Txn", "Up", "Down", "Name") + fmt.Fprintf(w, fmtPattern, "────", "───", "──", "────", "────") + for _, m := range stats { + txnStr := "✔" + if !m.Tx { + txnStr = "✘" + } + fmt.Fprintf(w, fmtPattern, + strings.TrimPrefix(filepath.Ext(m.FileName), "."), + txnStr, + m.UpCount, + m.DownCount, + filepath.Base(m.FileName), + ) + } + return w.Flush() +} diff --git a/internal/migrationstats/migration_go.go b/internal/migrationstats/migration_go.go new file mode 100644 index 0000000..0a15037 --- /dev/null +++ b/internal/migrationstats/migration_go.go @@ -0,0 +1,129 @@ +package migrationstats + +import ( + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io" +) + +const ( + registerGoFuncName = "AddMigration" + registerGoFuncNameNoTx = "AddMigrationNoTx" +) + +type goMigration struct { + name string + useTx *bool + upFuncName, downFuncName string +} + +func parseGoFile(r io.Reader) (*goMigration, error) { + astFile, err := parser.ParseFile( + token.NewFileSet(), + "", // filename + r, + // We don't need to resolve imports, so we can skip it. + // This speeds up the parsing process. + // See https://github.com/golang/go/issues/46485 + parser.SkipObjectResolution, + ) + if err != nil { + return nil, err + } + for _, decl := range astFile.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok || fn == nil || fn.Name == nil { + continue + } + if fn.Name.Name == "init" { + return parseInitFunc(fn) + } + } + return nil, errors.New("no init function") +} + +func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) { + if fd == nil { + return nil, fmt.Errorf("function declaration must not be nil") + } + if fd.Body == nil { + return nil, fmt.Errorf("no function body") + } + if len(fd.Body.List) == 0 { + return nil, fmt.Errorf("no registered goose functions") + } + gf := new(goMigration) + for _, statement := range fd.Body.List { + expr, ok := statement.(*ast.ExprStmt) + if !ok { + continue + } + call, ok := expr.X.(*ast.CallExpr) + if !ok { + continue + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel == nil { + continue + } + funcName := sel.Sel.Name + b := false + switch funcName { + case registerGoFuncName: + b = true + gf.useTx = &b + case registerGoFuncNameNoTx: + gf.useTx = &b + default: + continue + } + if gf.name != "" { + return nil, fmt.Errorf("found duplicate registered functions:\nprevious: %v\ncurrent: %v", gf.name, funcName) + } + gf.name = funcName + + if len(call.Args) != 2 { + return nil, fmt.Errorf("registered goose functions have 2 arguments: got %d", len(call.Args)) + } + getNameFromExpr := func(expr ast.Expr) (string, error) { + arg, ok := expr.(*ast.Ident) + if !ok { + return "", fmt.Errorf("failed to assert argument identifer: got %T", arg) + } + return arg.Name, nil + } + var err error + gf.upFuncName, err = getNameFromExpr(call.Args[0]) + if err != nil { + return nil, err + } + gf.downFuncName, err = getNameFromExpr(call.Args[1]) + if err != nil { + return nil, err + } + } + // validation + switch gf.name { + case registerGoFuncName, registerGoFuncNameNoTx: + default: + return nil, fmt.Errorf("goose register function must be one of: %s or %s", + registerGoFuncName, + registerGoFuncNameNoTx, + ) + } + if gf.useTx == nil { + return nil, errors.New("validation error: failed to identify transaction: got nil bool") + } + // The up and down functions can either be named Go functions or "nil", an + // empty string means there is a flaw in our parsing logic of the Go source code. + if gf.upFuncName == "" { + return nil, fmt.Errorf("validation error: up function is empty string") + } + if gf.downFuncName == "" { + return nil, fmt.Errorf("validation error: down function is empty string") + } + return gf, nil +} diff --git a/internal/migrationstats/migration_sql.go b/internal/migrationstats/migration_sql.go new file mode 100644 index 0000000..d17ab22 --- /dev/null +++ b/internal/migrationstats/migration_sql.go @@ -0,0 +1,47 @@ +package migrationstats + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + + "github.com/pressly/goose/v3/internal/sqlparser" +) + +type sqlMigration struct { + useTx bool + upCount, downCount int +} + +func parseSQLFile(r io.Reader, debug bool) (*sqlMigration, error) { + by, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + upStatements, txUp, err := sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionUp, + debug, + ) + if err != nil { + return nil, err + } + downStatements, txDown, err := sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionDown, + debug, + ) + if err != nil { + return nil, err + } + // This is a sanity check to ensure that the parser is behaving as expected. + if txUp != txDown { + return nil, fmt.Errorf("up and down statements must have the same transaction mode") + } + return &sqlMigration{ + useTx: txUp, + upCount: len(upStatements), + downCount: len(downStatements), + }, nil +} diff --git a/internal/migrationstats/migrationstats.go b/internal/migrationstats/migrationstats.go new file mode 100644 index 0000000..fdaa606 --- /dev/null +++ b/internal/migrationstats/migrationstats.go @@ -0,0 +1,78 @@ +package migrationstats + +import ( + "fmt" + "io" + "path/filepath" + + "github.com/pressly/goose/v3" +) + +// FileWalker walks all files for GatherStats. +type FileWalker interface { + // Walk invokes fn for each file. + Walk(fn func(filename string, r io.Reader) error) error +} + +// Stats contains the stats for a migration file. +type Stats struct { + // FileName is the name of the file. + FileName string + // Version is the version of the migration. + Version int64 + // Tx is true if the .sql migration file has a +goose NO TRANSACTION annotation + // or the .go migration file calls AddMigrationNoTx. + Tx bool + // UpCount is the number of statements in the Up migration. + UpCount int + // DownCount is the number of statements in the Down migration. + DownCount int +} + +// GatherStats returns the migration file stats. +func GatherStats(fw FileWalker, debug bool) ([]*Stats, error) { + var stats []*Stats + err := fw.Walk(func(filename string, r io.Reader) error { + version, err := goose.NumericComponent(filename) + if err != nil { + return fmt.Errorf("failed to get version from file %q: %w", filename, err) + } + var up, down int + var tx bool + switch filepath.Ext(filename) { + case ".sql": + m, err := parseSQLFile(r, debug) + if err != nil { + return fmt.Errorf("failed to parse file %q: %w", filename, err) + } + up, down = m.upCount, m.downCount + tx = m.useTx + case ".go": + m, err := parseGoFile(r) + if err != nil { + return fmt.Errorf("failed to parse file %q: %w", filename, err) + } + up, down = nilAsNumber(m.upFuncName), nilAsNumber(m.downFuncName) + tx = *m.useTx + } + stats = append(stats, &Stats{ + FileName: filename, + Version: version, + Tx: tx, + UpCount: up, + DownCount: down, + }) + return nil + }) + if err != nil { + return nil, err + } + return stats, nil +} + +func nilAsNumber(s string) int { + if s != "nil" { + return 1 + } + return 0 +} diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go new file mode 100644 index 0000000..67a65a3 --- /dev/null +++ b/internal/migrationstats/migrationstats_test.go @@ -0,0 +1,181 @@ +package migrationstats + +import ( + "strings" + "testing" + + "github.com/pressly/goose/v3/internal/check" +) + +func TestParsingGoMigrations(t *testing.T) { + tests := []struct { + name string + input string + wantUpName, wantDownName string + wantTx bool + }{ + // AddMigration + {"upAndDown", upAndDown, "up001", "down001", true}, + {"downOnly", downOnly, "nil", "down002", true}, + {"upOnly", upOnly, "up003", "nil", true}, + {"upAndDownNil", upAndDownNil, "nil", "nil", true}, + // AddMigrationNoTx + {"upAndDownNoTx", upAndDownNoTx, "up001", "down001", false}, + {"downOnlyNoTx", downOnlyNoTx, "nil", "down002", false}, + {"upOnlyNoTx", upOnlyNoTx, "up003", "nil", false}, + {"upAndDownNilNoTx", upAndDownNilNoTx, "nil", "nil", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g, err := parseGoFile(strings.NewReader(tc.input)) + check.NoError(t, err) + check.Equal(t, g.useTx != nil, true) + check.Bool(t, *g.useTx, tc.wantTx) + check.Equal(t, g.downFuncName, tc.wantDownName) + check.Equal(t, g.upFuncName, tc.wantUpName) + }) + } +} + +func TestParsingGoMigrationsError(t *testing.T) { + _, err := parseGoFile(strings.NewReader(emptyInit)) + check.HasError(t, err) + check.Contains(t, err.Error(), "no registered goose functions") + + _, err = parseGoFile(strings.NewReader(wrongName)) + check.HasError(t, err) + check.Contains(t, err.Error(), "AddMigration or AddMigrationNoTx") +} + +var ( + upAndDown = `package foo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up001, down001) +} + +func up001(tx *sql.Tx) error { return nil } + +func down001(tx *sql.Tx) error { return nil }` + + downOnly = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(nil, down002) +} + +func down002(tx *sql.Tx) error { return nil }` + + upOnly = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up003, nil) +} + +func up003(tx *sql.Tx) error { return nil }` + + upAndDownNil = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(nil, nil) +}` +) +var ( + upAndDownNoTx = `package foo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up001, down001) +} + +func up001(db *sql.DB) error { return nil } + +func down001(db *sql.DB) error { return nil }` + + downOnlyNoTx = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(nil, down002) +} + +func down002(db *sql.DB) error { return nil }` + + upOnlyNoTx = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up003, nil) +} + +func up003(db *sql.DB) error { return nil }` + + upAndDownNilNoTx = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(nil, nil) +}` +) + +var ( + emptyInit = `package testgo + +func init() {}` + + wrongName = `package testgo + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationWrongName(nil, nil) +}` +) diff --git a/internal/migrationstats/migrationstatsos/migrationstatsos.go b/internal/migrationstats/migrationstatsos/migrationstatsos.go new file mode 100644 index 0000000..ed86914 --- /dev/null +++ b/internal/migrationstats/migrationstatsos/migrationstatsos.go @@ -0,0 +1,46 @@ +package migrationstatsos + +import ( + "io" + "os" + "path/filepath" + + "github.com/pressly/goose/v3/internal/migrationstats" +) + +// NewFileWalker returns a new FileWalker for the given filenames. +// +// Filenames without a .sql or .go extension are ignored. +func NewFileWalker(filenames ...string) migrationstats.FileWalker { + return &fileWalker{ + filenames: filenames, + } +} + +type fileWalker struct { + filenames []string +} + +var _ migrationstats.FileWalker = (*fileWalker)(nil) + +func (f *fileWalker) Walk(fn func(filename string, r io.Reader) error) error { + for _, filename := range f.filenames { + ext := filepath.Ext(filename) + if ext != ".sql" && ext != ".go" { + continue + } + if err := walk(filename, fn); err != nil { + return err + } + } + return nil +} + +func walk(filename string, fn func(filename string, r io.Reader) error) error { + file, err := os.Open(filename) + if err != nil { + return err + } + defer file.Close() + return fn(filename, file) +} diff --git a/internal/sqlparser/parser.go b/internal/sqlparser/parser.go index 9bb4263..369dd8d 100644 --- a/internal/sqlparser/parser.go +++ b/internal/sqlparser/parser.go @@ -89,7 +89,7 @@ var bufferPool = sync.Pool{ // within a statement. For these cases, we provide the explicit annotations // 'StatementBegin' and 'StatementEnd' to allow the script to // tell us to ignore semicolons. -func ParseSQLMigration(r io.Reader, direction Direction, verbose bool) (stmts []string, useTx bool, err error) { +func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) { scanBufPtr := bufferPool.Get().(*[]byte) scanBuf := *scanBufPtr defer bufferPool.Put(scanBufPtr) @@ -97,13 +97,13 @@ func ParseSQLMigration(r io.Reader, direction Direction, verbose bool) (stmts [] scanner := bufio.NewScanner(r) scanner.Buffer(scanBuf, scanBufSize) - stateMachine := newStateMachine(start, verbose) + stateMachine := newStateMachine(start, debug) useTx = true var buf bytes.Buffer for scanner.Scan() { line := scanner.Text() - if verbose { + if debug { log.Println(line) } if stateMachine.get() == start && strings.TrimSpace(line) == "" {