diff --git a/go.mod b/go.mod index 0230e2c..8276042 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 go.uber.org/multierr v1.11.0 + golang.org/x/sync v0.4.0 modernc.org/sqlite v1.26.0 ) diff --git a/go.sum b/go.sum index e2c6ea0..537f343 100644 --- a/go.sum +++ b/go.sum @@ -184,7 +184,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/migrate/doc.go b/internal/migrate/doc.go deleted file mode 100644 index 5fbee15..0000000 --- a/internal/migrate/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package migrate defines a Migration struct and implements the migration logic for executing Go -// and SQL migrations. -// -// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported. -// - For SQL migrations, all three are supported. -// -// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first -// time it is executed. -package migrate diff --git a/internal/migrate/migration.go b/internal/migrate/migration.go deleted file mode 100644 index 23a0514..0000000 --- a/internal/migrate/migration.go +++ /dev/null @@ -1,166 +0,0 @@ -package migrate - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/pressly/goose/v3/internal/sqlextended" -) - -type Migration struct { - // Fullpath is the full path to the migration file. - // - // Example: /path/to/migrations/123_create_users_table.go - Fullpath string - // Version is the version of the migration. - Version int64 - // Type is the type of migration. - Type MigrationType - // A migration is either a Go migration or a SQL migration, but never both. - // - // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is - // an optimization to avoid parsing the SQL migration if it is never required. Also, the - // majority of the time migrations are incremental, so it is likely that the user will only want - // to run the last few migrations, and there is no need to parse ALL prior migrations. - // - // Exactly one of these fields will be set: - Go *Go - // -- or -- - SQLParsed bool - SQL *SQL -} - -type MigrationType int - -const ( - TypeGo MigrationType = iota + 1 - TypeSQL -) - -func (t MigrationType) String() string { - switch t { - case TypeGo: - return "go" - case TypeSQL: - return "sql" - default: - // This should never happen. - return "unknown" - } -} - -func (m *Migration) UseTx() bool { - switch m.Type { - case TypeGo: - return m.Go.UseTx - case TypeSQL: - return m.SQL.UseTx - default: - // This should never happen. - panic("unknown migration type: use tx") - } -} - -func (m *Migration) IsEmpty(direction bool) bool { - switch m.Type { - case TypeGo: - return m.Go.IsEmpty(direction) - case TypeSQL: - return m.SQL.IsEmpty(direction) - default: - // This should never happen. - panic("unknown migration type: is empty") - } -} - -func (m *Migration) GetSQLStatements(direction bool) ([]string, error) { - if m.Type != TypeSQL { - return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type) - } - if m.SQL == nil { - return nil, errors.New("sql migration has not been initialized") - } - if !m.SQLParsed { - return nil, errors.New("sql migration has not been parsed") - } - if direction { - return m.SQL.UpStatements, nil - } - return m.SQL.DownStatements, nil -} - -type Go struct { - // We used an explicit bool instead of relying on a pointer because registered funcs may be nil. - // These are still valid Go and versioned migrations, but they are just empty. - // - // For example: goose.AddMigration(nil, nil) - UseTx bool - - // Only one of these func pairs will be set: - UpFn, DownFn func(context.Context, *sql.Tx) error - // -- or -- - UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error -} - -func (g *Go) IsEmpty(direction bool) bool { - if direction { - return g.UpFn == nil && g.UpFnNoTx == nil - } - return g.DownFn == nil && g.DownFnNoTx == nil -} - -func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error { - var fn func(context.Context, *sql.Tx) error - if direction { - fn = g.UpFn - } else { - fn = g.DownFn - } - if fn != nil { - return fn(ctx, tx) - } - return nil -} - -func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { - var fn func(context.Context, *sql.DB) error - if direction { - fn = g.UpFnNoTx - } else { - fn = g.DownFnNoTx - } - if fn != nil { - return fn(ctx, db) - } - return nil -} - -type SQL struct { - UseTx bool - UpStatements []string - DownStatements []string -} - -func (s *SQL) IsEmpty(direction bool) bool { - if direction { - return len(s.UpStatements) == 0 - } - return len(s.DownStatements) == 0 -} - -func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { - var statements []string - if direction { - statements = s.UpStatements - } else { - statements = s.DownStatements - } - for _, stmt := range statements { - if _, err := db.ExecContext(ctx, stmt); err != nil { - return err - } - } - return nil -} diff --git a/internal/migrate/parse.go b/internal/migrate/parse.go deleted file mode 100644 index 18a66b4..0000000 --- a/internal/migrate/parse.go +++ /dev/null @@ -1,75 +0,0 @@ -package migrate - -import ( - "bytes" - "io" - "io/fs" - - "github.com/pressly/goose/v3/internal/sqlparser" -) - -// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it -// will not be parsed again. -// -// Important: This function will mutate SQL migrations. -func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error { - for _, m := range migrations { - if m.Type == TypeSQL && !m.SQLParsed { - parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug) - if err != nil { - return err - } - m.SQLParsed = true - m.SQL = parsedSQLMigration - } - } - return nil -} - -// parse is used to determine which direction to parse the SQL migration. -type parse int - -const ( - // parseAll parses all SQL statements in BOTH directions. - parseAll parse = iota + 1 - // parseUp parses all SQL statements in the UP direction. - parseUp - // parseDown parses all SQL statements in the DOWN direction. - parseDown -) - -func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) { - r, err := fsys.Open(filename) - if err != nil { - return nil, err - } - by, err := io.ReadAll(r) - if err != nil { - return nil, err - } - if err := r.Close(); err != nil { - return nil, err - } - s := new(SQL) - if p == parseAll || p == parseUp { - s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration( - bytes.NewReader(by), - sqlparser.DirectionUp, - debug, - ) - if err != nil { - return nil, err - } - } - if p == parseAll || p == parseDown { - s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration( - bytes.NewReader(by), - sqlparser.DirectionDown, - debug, - ) - if err != nil { - return nil, err - } - } - return s, nil -} diff --git a/internal/provider/collect.go b/internal/provider/collect.go index cf12961..6658c80 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -9,15 +9,56 @@ import ( "strings" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/migrate" ) +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +type Source struct { + // Type is the type of migration. + Type MigrationType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +func newSource(t MigrationType, fullpath string, version int64) Source { + return Source{ + Type: t, + Fullpath: fullpath, + Version: version, + } +} + // fileSources represents a collection of migration files on the filesystem. type fileSources struct { sqlSources []Source goSources []Source } +func (s *fileSources) lookup(t MigrationType, version int64) *Source { + switch t { + case TypeGo: + for _, source := range s.goSources { + if source.Version == version { + return &source + } + } + case TypeSQL: + for _, source := range s.sqlSources { + if source.Version == version { + return &source + } + } + } + return nil +} + // collectFileSources scans the file system for migration files that have a numeric prefix (greater // than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, // in which case an empty fileSources is returned. @@ -69,15 +110,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil } switch filepath.Ext(base) { case ".sql": - sources.sqlSources = append(sources.sqlSources, Source{ - Fullpath: fullpath, - Version: version, - }) + sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version)) case ".go": - sources.goSources = append(sources.goSources, Source{ - Fullpath: fullpath, - Version: version, - }) + sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version)) default: // Should never happen since we already filtered out all other file types. return nil, fmt.Errorf("unknown migration type: %s", base) @@ -89,19 +124,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil return sources, nil } -func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) { - var migrations []*migrate.Migration - migrationLookup := make(map[int64]*migrate.Migration) +func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) { + var migrations []*migration + migrationLookup := make(map[int64]*migration) // Add all SQL migrations to the list of migrations. - for _, s := range sources.sqlSources { - m := &migrate.Migration{ - Type: migrate.TypeSQL, - Fullpath: s.Fullpath, - Version: s.Version, - SQLParsed: false, + for _, source := range sources.sqlSources { + m := &migration{ + Source: source, + SQL: nil, // SQL migrations are parsed lazily. } migrations = append(migrations, m) - migrationLookup[s.Version] = m + migrationLookup[source.Version] = m } // If there are no Go files in the filesystem and no registered Go migrations, return early. if len(sources.goSources) == 0 && len(registerd) == 0 { @@ -127,38 +160,41 @@ func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migra // migrations may not have a corresponding file on disk. Which is fine! We include them // wholesale as part of migrations. This allows users to build a custom binary that only embeds // the SQL migration files. - for _, r := range registerd { + for version, r := range registerd { + var fullpath string + if s := sources.lookup(TypeGo, version); s != nil { + fullpath = s.Fullpath + } // Ensure there are no duplicate versions. - if existing, ok := migrationLookup[r.Version]; ok { + if existing, ok := migrationLookup[version]; ok { + if fullpath == "" { + fullpath = "manually registered (no source)" + } return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", - r.Version, - existing, - filepath.Base(r.Source), + version, + existing.Source.Fullpath, + fullpath, ) } - m := &migrate.Migration{ - Fullpath: r.Source, // May be empty if the migration was registered manually. - Version: r.Version, - Type: migrate.TypeGo, - Go: &migrate.Go{ - UseTx: r.UseTx, - UpFn: r.UpFnContext, - UpFnNoTx: r.UpFnNoTxContext, - DownFn: r.DownFnContext, - DownFnNoTx: r.DownFnNoTxContext, - }, + m := &migration{ + // Note, the fullpath may be empty if the migration was registered manually. + Source: newSource(TypeGo, fullpath, version), + Go: r, } migrations = append(migrations, m) - migrationLookup[r.Version] = m + migrationLookup[version] = m } // Sort migrations by version in ascending order. sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Version < migrations[j].Version + return migrations[i].Source.Version < migrations[j].Source.Version }) return migrations, nil } func unregisteredError(unregistered []string) error { + const ( + hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations" + ) f := "file" if len(unregistered) > 1 { f += "s" @@ -169,8 +205,9 @@ func unregisteredError(unregistered []string) error { for _, name := range unregistered { b.WriteString("\t" + name + "\n") } + hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL) + b.WriteString(hint) b.WriteString("\n") - b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations") return errors.New(b.String()) } diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index a5ee2d3..401a1ce 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -10,14 +10,14 @@ import ( func TestCollectFileSources(t *testing.T) { t.Parallel() - t.Run("nil", func(t *testing.T) { + t.Run("nil_fsys", func(t *testing.T) { sources, err := collectFileSources(nil, false, nil) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) - t.Run("empty", func(t *testing.T) { + t.Run("empty_fsys", func(t *testing.T) { sources, err := collectFileSources(fstest.MapFS{}, false, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) @@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - {Fullpath: "00001_foo.sql", Version: 1}, - {Fullpath: "00002_bar.sql", Version: 2}, - {Fullpath: "00003_baz.sql", Version: 3}, - {Fullpath: "00110_qux.sql", Version: 110}, + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00002_bar.sql", 2), + newSource(TypeSQL, "00003_baz.sql", 3), + newSource(TypeSQL, "00110_qux.sql", 110), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - {Fullpath: "00001_foo.sql", Version: 1}, - {Fullpath: "00003_baz.sql", Version: 3}, + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00003_baz.sql", 3), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -159,18 +159,146 @@ func TestCollectFileSources(t *testing.T) { } } assertDirpath(".", []Source{ - {Fullpath: "876_a.sql", Version: 876}, + newSource(TypeSQL, "876_a.sql", 876), }) assertDirpath("dir1", []Source{ - {Fullpath: "101_a.sql", Version: 101}, - {Fullpath: "102_b.sql", Version: 102}, - {Fullpath: "103_c.sql", Version: 103}, + newSource(TypeSQL, "101_a.sql", 101), + newSource(TypeSQL, "102_b.sql", 102), + newSource(TypeSQL, "103_c.sql", 103), + }) + assertDirpath("dir2", []Source{ + newSource(TypeSQL, "201_a.sql", 201), }) - assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}}) assertDirpath("dir3", nil) }) } +func TestMerge(t *testing.T) { + t.Parallel() + + t.Run("with_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + // Go + "migrations/00002_bar.go": {Data: []byte(`package migrations`)}, + "migrations/00003_baz.go": {Data: []byte(`package migrations`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Equal(t, len(sources.sqlSources), 1) + check.Equal(t, len(sources.goSources), 2) + src1 := sources.lookup(TypeSQL, 1) + check.Bool(t, src1 != nil, true) + src2 := sources.lookup(TypeGo, 2) + check.Bool(t, src2 != nil, true) + src3 := sources.lookup(TypeGo, 3) + check.Bool(t, src3 != nil, true) + + t.Run("valid", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 2: {version: 2}, + 3: {version: 3}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 3) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) + }) + t.Run("unregistered_all", func(t *testing.T) { + _, err := merge(sources, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:") + check.Contains(t, err.Error(), "00002_bar.go") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("unregistered_some", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 2: {version: 2}, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("duplicate_sql", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 1: {version: 1}, // duplicate. SQL already exists. + 2: {version: 2}, + 3: {version: 3}, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + }) + t.Run("no_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00005_baz.sql": sqlMapFile, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 3: {version: 3}, + // 4 is missing + 6: {version: 6}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 5) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], newSource(TypeGo, "", 6)) + }) + }) + t.Run("partial_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + // This is the only Go file on disk. + 2: {version: 2}, + // These are not on disk. Explicitly registered. + 3: {version: 3}, + 6: {version: 6}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 4) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) + }) + }) +} + +func assertMigration(t *testing.T, got *migration, want Source) { + t.Helper() + check.Equal(t, got.Source, want) + switch got.Source.Type { + case TypeGo: + check.Equal(t, got.Go.version, want.Version) + case TypeSQL: + check.Bool(t, got.SQL == nil, true) + default: + t.Fatalf("unknown migration type: %s", got.Source.Type) + } +} + func newSQLOnlyFS() fstest.MapFS { return fstest.MapFS{ "migrations/00001_foo.sql": sqlMapFile, diff --git a/internal/provider/migration.go b/internal/provider/migration.go new file mode 100644 index 0000000..cf98abc --- /dev/null +++ b/internal/provider/migration.go @@ -0,0 +1,119 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +type migration struct { + Source Source + // A migration is either a Go migration or a SQL migration, but never both. + // + // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is + // an optimization to avoid parsing the SQL migration if it is never required. Also, the + // majority of the time migrations are incremental, so it is likely that the user will only want + // to run the last few migrations, and there is no need to parse ALL prior migrations. + // + // Exactly one of these fields will be set: + Go *goMigration + // -- OR -- + SQL *sqlMigration +} + +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. + return fmt.Sprintf("unknown (%d)", t) + } +} + +func (m *migration) GetSQLStatements(direction bool) ([]string, error) { + if m.Source.Type != TypeSQL { + return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type) + } + if m.SQL == nil { + return nil, errors.New("sql migration has not been parsed") + } + if direction { + return m.SQL.UpStatements, nil + } + return m.SQL.DownStatements, nil +} + +func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.Tx) error + if direction && g.up != nil { + fn = g.up.Run + } + if !direction && g.down != nil { + fn = g.down.Run + } + if fn != nil { + return fn(ctx, tx) + } + return nil +} + +func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.DB) error + if direction && g.up != nil { + fn = g.up.RunNoTx + } + if !direction && g.down != nil { + fn = g.down.RunNoTx + } + if fn != nil { + return fn(ctx, db) + } + return nil +} + +type sqlMigration struct { + UseTx bool + UpStatements []string + DownStatements []string +} + +func (s *sqlMigration) IsEmpty(direction bool) bool { + if direction { + return len(s.UpStatements) == 0 + } + return len(s.DownStatements) == 0 +} + +func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { + var statements []string + if direction { + statements = s.UpStatements + } else { + statements = s.DownStatements + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 6702f07..7d50850 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -4,12 +4,14 @@ import ( "context" "database/sql" "errors" + "fmt" "io/fs" "os" "time" - "github.com/pressly/goose/v3/internal/migrate" + "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/sqlparser" ) var ( @@ -17,6 +19,8 @@ var ( ErrNoMigrations = errors.New("no migrations found") ) +var registeredGoMigrations = make(map[int64]*goose.Migration) + // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For @@ -68,7 +72,59 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if err != nil { return nil, err } - migrations, err := merge(sources, nil) + // + // TODO(mf): move the merging of Go migrations into a separate function. + // + registered := make(map[int64]*goMigration) + // Add user-registered Go migrations. + for version, m := range cfg.registered { + registered[version] = &goMigration{ + version: version, + up: m.up, + down: m.down, + } + } + // Add init() functions. This is a bit ugly because we need to convert from the old Migration + // struct to the new goMigration struct. + for version, m := range registeredGoMigrations { + if _, ok := registered[version]; ok { + return nil, fmt.Errorf("go migration with version %d already registered", version) + } + g := &goMigration{ + version: version, + } + if m == nil { + return nil, errors.New("registered migration with nil init function") + } + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext") + } + // Up + if m.UpFnContext != nil { + g.up = &GoMigration{ + Run: m.UpFnContext, + } + } else if m.UpFnNoTxContext != nil { + g.up = &GoMigration{ + RunNoTx: m.UpFnNoTxContext, + } + } + // Down + if m.DownFnContext != nil { + g.down = &GoMigration{ + Run: m.DownFnContext, + } + } else if m.DownFnNoTxContext != nil { + g.down = &GoMigration{ + RunNoTx: m.DownFnNoTxContext, + } + } + registered[version] = g + } + migrations, err := merge(sources, registered) if err != nil { return nil, err } @@ -98,7 +154,7 @@ type Provider struct { fsys fs.FS cfg config store sqladapter.Store - migrations []*migrate.Migration + migrations []*migration } // State represents the state of a migration. @@ -137,48 +193,12 @@ func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { return 0, errors.New("not implemented") } -// SourceType represents the type of migration source. -type SourceType string - -const ( - // SourceTypeSQL represents a SQL migration. - SourceTypeSQL SourceType = "sql" - // SourceTypeGo represents a Go migration. - SourceTypeGo SourceType = "go" -) - -// Source represents a single migration source. -// -// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if -// the migration has a corresponding file on disk. It will be empty if the migration was registered -// manually. -type Source struct { - // Type is the type of migration. - Type SourceType - // Full path to the migration file. - // - // Example: /path/to/migrations/001_create_users_table.sql - Fullpath string - // Version is the version of the migration. - Version int64 -} - // ListSources returns a list of all available migration sources the provider is aware of, sorted in // ascending order by version. func (p *Provider) ListSources() []*Source { sources := make([]*Source, 0, len(p.migrations)) for _, m := range p.migrations { - s := &Source{ - Fullpath: m.Fullpath, - Version: m.Version, - } - switch m.Type { - case migrate.TypeSQL: - s.Type = SourceTypeSQL - case migrate.TypeGo: - s.Type = SourceTypeGo - } - sources = append(sources, s) + sources = append(sources, &m.Source) } return sources } @@ -240,3 +260,25 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } + +// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations and is not safe for concurrent use. +func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error { + for _, m := range migrations { + // If the migration is a SQL migration, and it has not been parsed, parse it. + if m.Source.Type == TypeSQL && m.SQL == nil { + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + if err != nil { + return err + } + m.SQL = &sqlMigration{ + UseTx: parsed.UseTx, + UpStatements: parsed.Up, + DownStatements: parsed.Down, + } + } + } + return nil +} diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index d8060c4..f3ed15b 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -1,6 +1,8 @@ package provider import ( + "context" + "database/sql" "errors" "fmt" @@ -72,11 +74,67 @@ func WithExcludes(excludes []string) ProviderOption { }) } +// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration]. +type GoMigration struct { + // One of the following must be set: + Run func(context.Context, *sql.Tx) error + // -- OR -- + RunNoTx func(context.Context, *sql.DB) error +} + +// WithGoMigration registers a Go migration with the given version. +// +// If WithGoMigration is called multiple times with the same version, an error is returned. Both up +// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set. +func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { + return configFunc(func(c *config) error { + if version < 1 { + return fmt.Errorf("go migration version must be greater than 0") + } + if _, ok := c.registered[version]; ok { + return fmt.Errorf("go migration with version %d already registered", version) + } + // Allow nil up/down functions. This enables users to apply "no-op" migrations, while + // versioning them. + if up != nil { + if up.Run == nil && up.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have an up function", version) + } + if up.Run != nil && up.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version) + } + } + if down != nil { + if down.Run == nil && down.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have a down function", version) + } + if down.Run != nil && down.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version) + } + } + c.registered[version] = &goMigration{ + version: version, + up: up, + down: down, + } + return nil + }) +} + +type goMigration struct { + version int64 + up, down *GoMigration +} + type config struct { tableName string verbose bool excludes map[string]bool + // Go migrations registered by the user. These will be merged/resolved with migrations from the + // filesystem and init() functions. + registered map[int64]*goMigration + // Locking options lockEnabled bool sessionLocker lock.SessionLocker diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 10aed48..c8b5eff 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -36,11 +36,11 @@ func TestProvider(t *testing.T) { // 1 check.Equal(t, sources[0].Version, int64(1)) check.Equal(t, sources[0].Fullpath, "001_foo.sql") - check.Equal(t, sources[0].Type, provider.SourceTypeSQL) + check.Equal(t, sources[0].Type, provider.TypeSQL) // 2 check.Equal(t, sources[1].Version, int64(2)) check.Equal(t, sources[1].Fullpath, "002_bar.sql") - check.Equal(t, sources[1].Type, provider.SourceTypeSQL) + check.Equal(t, sources[1].Type, provider.TypeSQL) } var ( diff --git a/internal/migrate/run.go b/internal/provider/run.go similarity index 71% rename from internal/migrate/run.go rename to internal/provider/run.go index 7b7a883..f5ca250 100644 --- a/internal/migrate/run.go +++ b/internal/provider/run.go @@ -1,4 +1,4 @@ -package migrate +package provider import ( "context" @@ -8,10 +8,10 @@ import ( ) // Run runs the migration inside of a transaction. -func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { - switch m.Type { +func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("tx: sql migration has not been parsed") } return m.SQL.run(ctx, tx, direction) @@ -19,14 +19,14 @@ func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { return m.Go.run(ctx, tx, direction) } // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } // RunNoTx runs the migration without a transaction. -func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { - switch m.Type { +func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("db: sql migration has not been parsed") } return m.SQL.run(ctx, db, direction) @@ -34,14 +34,14 @@ func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) err return m.Go.runNoTx(ctx, db, direction) } // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } // RunConn runs the migration without a transaction using the provided connection. -func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { - switch m.Type { +func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("conn: sql migration has not been parsed") } return m.SQL.run(ctx, conn, direction) @@ -49,5 +49,5 @@ func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") } // This should never happen. - return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } diff --git a/internal/sqlparser/parse.go b/internal/sqlparser/parse.go new file mode 100644 index 0000000..e993587 --- /dev/null +++ b/internal/sqlparser/parse.go @@ -0,0 +1,54 @@ +package sqlparser + +import ( + "io/fs" + + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" +) + +type ParsedSQL struct { + UseTx bool + Up, Down []string +} + +func ParseAllFromFS(fsys fs.FS, filename string, debug bool) (*ParsedSQL, error) { + parsedSQL := new(ParsedSQL) + // TODO(mf): parse is called twice, once for up and once for down. This is inefficient. It + // should be possible to parse both directions in one pass. Also, UseTx is set once (but + // returned twice), which is unnecessary and potentially error-prone if the two calls to + // parseSQL disagree based on direction. + var g errgroup.Group + g.Go(func() error { + up, useTx, err := parse(fsys, filename, DirectionUp, debug) + if err != nil { + return err + } + parsedSQL.Up = up + parsedSQL.UseTx = useTx + return nil + }) + g.Go(func() error { + down, _, err := parse(fsys, filename, DirectionDown, debug) + if err != nil { + return err + } + parsedSQL.Down = down + return nil + }) + if err := g.Wait(); err != nil { + return nil, err + } + return parsedSQL, nil +} + +func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []string, _ bool, retErr error) { + r, err := fsys.Open(filename) + if err != nil { + return nil, false, err + } + defer func() { + retErr = multierr.Append(retErr, r.Close()) + }() + return ParseSQLMigration(r, direction, debug) +} diff --git a/internal/sqlparser/parse_test.go b/internal/sqlparser/parse_test.go new file mode 100644 index 0000000..632bbe1 --- /dev/null +++ b/internal/sqlparser/parse_test.go @@ -0,0 +1,82 @@ +package sqlparser_test + +import ( + "errors" + "os" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqlparser" +) + +func TestParseAllFromFS(t *testing.T) { + t.Parallel() + t.Run("file_not_exist", func(t *testing.T) { + mapFS := fstest.MapFS{} + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, os.ErrNotExist), true) + }) + t.Run("empty_file", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": &fstest.MapFile{}, + } + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Contains(t, err.Error(), "failed to parse migration") + check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation") + }) + t.Run("all_statements", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": newFile(` +-- +goose Up +`), + "002_bar.sql": newFile(` +-- +goose Up +-- +goose Down +`), + "003_baz.sql": newFile(` +-- +goose Up +CREATE TABLE foo (id int); +CREATE TABLE bar (id int); + +-- +goose Down +DROP TABLE bar; +`), + "004_qux.sql": newFile(` +-- +goose NO TRANSACTION +-- +goose Up +CREATE TABLE foo (id int); +-- +goose Down +DROP TABLE foo; +`), + } + parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 2, 1) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, false, 1, 1) + }) +} + +func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) { + t.Helper() + check.Bool(t, got != nil, true) + check.Equal(t, len(got.Up), up) + check.Equal(t, len(got.Down), down) + check.Equal(t, got.UseTx, useTx) +} + +func newFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +}