From a52c60d6fb5b20ccd9d9f01f8d2d6255fb4f5941 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Tue, 12 Dec 2023 09:04:02 -0500 Subject: [PATCH] fix: add context-aware functions to goose validate (#662) --- cmd/goose/main.go | 7 +-- internal/migrationstats/migration_go.go | 23 ++++++---- .../migrationstats/migrationstats_test.go | 45 ++++++++++++++++++- ...ionstatsos.go => migrationstats_walker.go} | 8 ++-- 4 files changed, 66 insertions(+), 17 deletions(-) rename internal/migrationstats/{migrationstatsos/migrationstatsos.go => migrationstats_walker.go} (78%) diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 0b6f807..c842171 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -19,7 +19,6 @@ import ( "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/cfg" "github.com/pressly/goose/v3/internal/migrationstats" - "github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos" ) var ( @@ -320,8 +319,10 @@ func printValidate(filename string, verbose bool) error { if err != nil { return err } - fileWalker := migrationstatsos.NewFileWalker(filenames...) - stats, err := migrationstats.GatherStats(fileWalker, false) + stats, err := migrationstats.GatherStats( + migrationstats.NewFileWalker(filenames...), + false, + ) if err != nil { return err } diff --git a/internal/migrationstats/migration_go.go b/internal/migrationstats/migration_go.go index ad8d57d..3509c27 100644 --- a/internal/migrationstats/migration_go.go +++ b/internal/migrationstats/migration_go.go @@ -7,11 +7,14 @@ import ( "go/parser" "go/token" "io" + "strings" ) const ( - registerGoFuncName = "AddMigration" - registerGoFuncNameNoTx = "AddMigrationNoTx" + registerGoFuncName = "AddMigration" + registerGoFuncNameNoTx = "AddMigrationNoTx" + registerGoFuncNameContext = "AddMigrationContext" + registerGoFuncNameNoTxContext = "AddMigrationNoTxContext" ) type goMigration struct { @@ -72,10 +75,10 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) { funcName := sel.Sel.Name b := false switch funcName { - case registerGoFuncName: + case registerGoFuncName, registerGoFuncNameContext: b = true gf.useTx = &b - case registerGoFuncNameNoTx: + case registerGoFuncNameNoTx, registerGoFuncNameNoTxContext: gf.useTx = &b default: continue @@ -107,11 +110,15 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) { } // validation switch gf.name { - case registerGoFuncName, registerGoFuncNameNoTx: + case registerGoFuncName, registerGoFuncNameNoTx, registerGoFuncNameContext, registerGoFuncNameNoTxContext: default: - return nil, fmt.Errorf("goose register function must be one of: %s or %s", - registerGoFuncName, - registerGoFuncNameNoTx, + return nil, fmt.Errorf("goose register function must be one of: %s", + strings.Join([]string{ + registerGoFuncName, + registerGoFuncNameNoTx, + registerGoFuncNameContext, + registerGoFuncNameNoTxContext, + }, ", "), ) } if gf.useTx == nil { diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go index 26c49fd..0383072 100644 --- a/internal/migrationstats/migrationstats_test.go +++ b/internal/migrationstats/migrationstats_test.go @@ -1,6 +1,8 @@ package migrationstats import ( + "os" + "path/filepath" "strings" "testing" @@ -38,6 +40,47 @@ func TestParsingGoMigrations(t *testing.T) { } } +func TestGoMigrationStats(t *testing.T) { + t.Parallel() + + base := "../../tests/gomigrations/success/testdata" + all, err := os.ReadDir(base) + check.NoError(t, err) + check.Equal(t, len(all), 16) + files := make([]string, 0, len(all)) + for _, f := range all { + files = append(files, filepath.Join(base, f.Name())) + } + stats, err := GatherStats(NewFileWalker(files...), false) + check.NoError(t, err) + check.Equal(t, len(stats), 16) + checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true) + checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true) + checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true) + checkGoStats(t, stats[3], "004_empty.go", 4, 0, 0, true) + checkGoStats(t, stats[4], "005_up_down_no_tx.go", 5, 1, 1, false) + checkGoStats(t, stats[5], "006_up_only_no_tx.go", 6, 1, 0, false) + checkGoStats(t, stats[6], "007_down_only_no_tx.go", 7, 0, 1, false) + checkGoStats(t, stats[7], "008_empty_no_tx.go", 8, 0, 0, false) + checkGoStats(t, stats[8], "009_up_down_ctx.go", 9, 1, 1, true) + checkGoStats(t, stats[9], "010_up_only_ctx.go", 10, 1, 0, true) + checkGoStats(t, stats[10], "011_down_only_ctx.go", 11, 0, 1, true) + checkGoStats(t, stats[11], "012_empty_ctx.go", 12, 0, 0, true) + checkGoStats(t, stats[12], "013_up_down_no_tx_ctx.go", 13, 1, 1, false) + checkGoStats(t, stats[13], "014_up_only_no_tx_ctx.go", 14, 1, 0, false) + checkGoStats(t, stats[14], "015_down_only_no_tx_ctx.go", 15, 0, 1, false) + checkGoStats(t, stats[15], "016_empty_no_tx_ctx.go", 16, 0, 0, false) +} + +func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) { + t.Helper() + check.Equal(t, filepath.Base(stats.FileName), filename) + check.Equal(t, stats.Version, version) + check.Equal(t, stats.UpCount, upCount) + check.Equal(t, stats.DownCount, downCount) + check.Equal(t, stats.Tx, tx) +} + func TestParsingGoMigrationsError(t *testing.T) { t.Parallel() _, err := parseGoFile(strings.NewReader(emptyInit)) @@ -46,7 +89,7 @@ func TestParsingGoMigrationsError(t *testing.T) { _, err = parseGoFile(strings.NewReader(wrongName)) check.HasError(t, err) - check.Contains(t, err.Error(), "AddMigration or AddMigrationNoTx") + check.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext") } var ( diff --git a/internal/migrationstats/migrationstatsos/migrationstatsos.go b/internal/migrationstats/migrationstats_walker.go similarity index 78% rename from internal/migrationstats/migrationstatsos/migrationstatsos.go rename to internal/migrationstats/migrationstats_walker.go index ed86914..4e224e6 100644 --- a/internal/migrationstats/migrationstatsos/migrationstatsos.go +++ b/internal/migrationstats/migrationstats_walker.go @@ -1,17 +1,15 @@ -package migrationstatsos +package migrationstats import ( "io" "os" "path/filepath" - - "github.com/pressly/goose/v3/internal/migrationstats" ) // NewFileWalker returns a new FileWalker for the given filenames. // // Filenames without a .sql or .go extension are ignored. -func NewFileWalker(filenames ...string) migrationstats.FileWalker { +func NewFileWalker(filenames ...string) FileWalker { return &fileWalker{ filenames: filenames, } @@ -21,7 +19,7 @@ type fileWalker struct { filenames []string } -var _ migrationstats.FileWalker = (*fileWalker)(nil) +var _ FileWalker = (*fileWalker)(nil) func (f *fileWalker) Walk(fn func(filename string, r io.Reader) error) error { for _, filename := range f.filenames {