diff --git a/internal/check/check.go b/internal/check/check.go index 76dfac7..f5d1b6d 100644 --- a/internal/check/check.go +++ b/internal/check/check.go @@ -8,6 +8,13 @@ import ( "testing" ) +func NotNil(t *testing.T, v any) { + t.Helper() + if v == nil { + t.Fatal("unexpected nil value") + } +} + func NoError(t *testing.T, err error) { t.Helper() if err != nil { diff --git a/internal/provider/collect.go b/internal/provider/collect.go index fd7d63e..a4a73c0 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -13,9 +13,9 @@ import ( func NewSource(t MigrationType, fullpath string, version int64) Source { return Source{ - Type: t, - Fullpath: fullpath, - Version: version, + Type: t, + Path: fullpath, + Version: version, } } @@ -133,7 +133,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration var unregistered []string for _, s := range sources.goSources { if _, ok := registerd[s.Version]; !ok { - unregistered = append(unregistered, s.Fullpath) + unregistered = append(unregistered, s.Path) } } if len(unregistered) > 0 { @@ -149,7 +149,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration fullpath := r.fullpath if fullpath == "" { if s := sources.lookup(TypeGo, version); s != nil { - fullpath = s.Fullpath + fullpath = s.Path } } // Ensure there are no duplicate versions. @@ -160,7 +160,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration } return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", version, - existing.Source.Fullpath, + existing.Source.Path, fullpath, ) } diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index 73b2642..8417e84 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -6,6 +6,7 @@ import ( "testing/fstest" "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqladapter" ) func TestCollectFileSources(t *testing.T) { @@ -120,13 +121,13 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.sqlSources), 2) check.Number(t, len(sources.goSources), 1) // 1 - check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql") + check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql") check.Equal(t, sources.sqlSources[0].Version, int64(1)) // 2 - check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql") + check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql") check.Equal(t, sources.sqlSources[1].Version, int64(5)) // 3 - check.Equal(t, sources.goSources[0].Fullpath, "4_something.go") + check.Equal(t, sources.goSources[0].Path, "4_something.go") check.Equal(t, sources.goSources[0].Version, int64(4)) }) t.Run("duplicate_versions", func(t *testing.T) { @@ -286,6 +287,65 @@ func TestMerge(t *testing.T) { }) } +func TestFindMissingMigrations(t *testing.T) { + t.Parallel() + + t.Run("db_has_max_version", func(t *testing.T) { + // Test case: database has migrations 1, 3, 4, 5, 7 + // Missing migrations: 2, 6 + // Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8 + dbMigrations := []*sqladapter.ListMigrationsResult{ + {Version: 1}, + {Version: 3}, + {Version: 4}, + {Version: 5}, + {Version: 7}, // <-- database max version_id + } + fsMigrations := []*migration{ + newMigration(1), + newMigration(2), // missing migration + newMigration(3), + newMigration(4), + newMigration(5), + newMigration(6), // missing migration + newMigration(7), // ----- database max version_id ----- + newMigration(8), // new migration + } + got := findMissingMigrations(dbMigrations, fsMigrations) + check.Number(t, len(got), 2) + check.Number(t, got[0].versionID, 2) + check.Number(t, got[1].versionID, 6) + + // Sanity check. + check.Number(t, len(findMissingMigrations(nil, nil)), 0) + check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0) + check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0) + }) + t.Run("fs_has_max_version", func(t *testing.T) { + dbMigrations := []*sqladapter.ListMigrationsResult{ + {Version: 1}, + {Version: 5}, + {Version: 2}, + } + fsMigrations := []*migration{ + newMigration(3), // new migration + newMigration(4), // new migration + } + got := findMissingMigrations(dbMigrations, fsMigrations) + check.Number(t, len(got), 2) + check.Number(t, got[0].versionID, 3) + check.Number(t, got[1].versionID, 4) + }) +} + +func newMigration(version int64) *migration { + return &migration{ + Source: Source{ + Version: version, + }, + } +} + func assertMigration(t *testing.T, got *migration, want Source) { t.Helper() check.Equal(t, got.Source, want) diff --git a/internal/provider/errors.go b/internal/provider/errors.go index e8ece38..16cdd3f 100644 --- a/internal/provider/errors.go +++ b/internal/provider/errors.go @@ -22,18 +22,19 @@ var ( // PartialError is returned when a migration fails, but some migrations already got applied. type PartialError struct { - // Applied are migrations that were applied successfully before the error occurred. + // Applied are migrations that were applied successfully before the error occurred. May be + // empty. Applied []*MigrationResult - // Failed contains the result of the migration that failed. + // Failed contains the result of the migration that failed. Cannot be nil. Failed *MigrationResult - // Err is the error that occurred while running the migration. + // Err is the error that occurred while running the migration and caused the failure. Err error } func (e *PartialError) Error() string { filename := "(file unknown)" - if e.Failed != nil && e.Failed.Source.Fullpath != "" { - filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath)) + if e.Failed != nil && e.Failed.Source.Path != "" { + filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path)) } return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) } diff --git a/internal/provider/migration.go b/internal/provider/migration.go index 87098cf..05faf01 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -29,7 +29,7 @@ func (m *migration) useTx(direction bool) bool { case TypeSQL: return m.SQL.UseTx case TypeGo: - if m.Go == nil { + if m.Go == nil || m.Go.isEmpty(direction) { return false } if direction { @@ -41,8 +41,18 @@ func (m *migration) useTx(direction bool) bool { return false } +func (m *migration) isEmpty(direction bool) bool { + switch m.Source.Type { + case TypeSQL: + return m.SQL == nil || m.SQL.IsEmpty(direction) + case TypeGo: + return m.Go == nil || m.Go.isEmpty(direction) + } + return true +} + func (m *migration) filename() string { - return filepath.Base(m.Source.Fullpath) + return filepath.Base(m.Source.Path) } // run runs the migration inside of a transaction. @@ -57,7 +67,7 @@ func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error { return m.Go.run(ctx, tx, direction) } // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) } // runNoTx runs the migration without a transaction. @@ -72,7 +82,7 @@ func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) err return m.Go.runNoTx(ctx, db, direction) } // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) } // runConn runs the migration without a transaction using the provided connection. @@ -87,7 +97,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") } // This should never happen. - return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path)) } type goMigration struct { @@ -95,6 +105,16 @@ type goMigration struct { up, down *GoMigration } +func (g *goMigration) isEmpty(direction bool) bool { + if g.up == nil && g.down == nil { + panic("go migration has no up or down") + } + if direction { + return g.up == nil + } + return g.down == nil +} + func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { return &goMigration{ fullpath: fullpath, diff --git a/internal/provider/misc.go b/internal/provider/misc.go index be84b46..717edff 100644 --- a/internal/provider/misc.go +++ b/internal/provider/misc.go @@ -8,24 +8,48 @@ import ( ) type Migration struct { - Version int64 - Source string // path to .sql script or go file - Registered bool - UseTx bool - UpFnContext func(context.Context, *sql.Tx) error - DownFnContext func(context.Context, *sql.Tx) error - - UpFnNoTxContext func(context.Context, *sql.DB) error - DownFnNoTxContext func(context.Context, *sql.DB) error + Version int64 + Source string // path to .sql script or go file + Registered bool + UpFnContext, DownFnContext func(context.Context, *sql.Tx) error + UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error } var registeredGoMigrations = make(map[int64]*Migration) +// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of +// the migrations are nil or if a migration with the same version has already been registered. +// +// Not safe for concurrent use. func SetGlobalGoMigrations(migrations []*Migration) error { for _, m := range migrations { if m == nil { return errors.New("cannot register nil go migration") } + if m.Version < 1 { + return errors.New("migration versions must be greater than zero") + } + if !m.Registered { + return errors.New("migration must be registered") + } + if m.Source != "" { + // If the source is set, expect it to be a file path with a numeric component that + // matches the version. + version, err := NumericComponent(m.Source) + if err != nil { + return err + } + if version != m.Version { + return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source) + } + } + // It's valid for all of these to be nil. + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") + } if _, ok := registeredGoMigrations[m.Version]; ok { return fmt.Errorf("go migration with version %d already registered", m.Version) } @@ -34,6 +58,9 @@ func SetGlobalGoMigrations(migrations []*Migration) error { return nil } +// ResetGlobalGoMigrations resets the global go migrations registry. +// +// Not safe for concurrent use. func ResetGlobalGoMigrations() { registeredGoMigrations = make(map[int64]*Migration) } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 3982ac3..2dd3350 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -19,8 +19,9 @@ import ( // github.com/lib/pq or github.com/jackc/pgx. // // fsys is the filesystem used to read the migration files, but may be nil. Most users will want to -// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is -// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub. +// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem. +// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out +// migrations using [fs.Sub]. // // See [ProviderOption] for more information on configuring the provider. // @@ -39,6 +40,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption } cfg := config{ registered: make(map[int64]*goMigration), + excludes: make(map[string]bool), } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { @@ -133,10 +135,12 @@ type Provider struct { // database. mu sync.Mutex - db *sql.DB - fsys fs.FS - cfg config - store sqladapter.Store + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store + + // migrations are ordered by version in ascending order. migrations []*migration } @@ -149,8 +153,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { // GetDBVersion returns the max version from the database, regardless of the applied order. For // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // applied, it returns 0. -// -// TODO(mf): this is not true? func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { return p.getDBVersion(ctx) } @@ -175,25 +177,28 @@ func (p *Provider) Close() error { return p.db.Close() } -// ApplyVersion applies exactly one migration at the specified version. If there is no source for -// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been -// applied already, this method returns [ErrAlreadyApplied]. +// ApplyVersion applies exactly one migration by version. If there is no source for the specified +// version, this method returns [ErrVersionNotFound]. If the migration has been applied already, +// this method returns [ErrAlreadyApplied]. // // When direction is true, the up migration is executed, and when direction is false, the down // migration is executed. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { + if version < 1 { + return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version) + } return p.apply(ctx, version, direction) } -// Up applies all pending migrations. If there are no new migrations to apply, this method returns -// empty list and nil error. +// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method +// returns empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { return p.up(ctx, false, math.MaxInt64) } // UpByOne applies the next available migration. If there are no migrations to apply, this method // returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. -func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { +func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { res, err := p.up(ctx, true, math.MaxInt64) if err != nil { return nil, err @@ -201,21 +206,28 @@ func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { if len(res) == 0 { return nil, ErrNoNextVersion } - return res, nil + // This should never happen. We should always have exactly one result and test for this. + if len(res) > 1 { + return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res)) + } + return res[0], nil } -// UpTo applies all available migrations up to and including the specified version. If there are no -// migrations to apply, this method returns empty list and nil error. +// UpTo applies all available migrations up to, and including, the specified version. If there are +// no migrations to apply, this method returns empty list and nil error. // // For instance, if there are three new migrations (9,10,11) and the current database version is 8 -// with a requested version of 10, only versions 9 and 10 will be applied. +// with a requested version of 10, only versions 9,10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + if version < 1 { + return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version) + } return p.up(ctx, false, version) } // Down rolls back the most recently applied migration. If there are no migrations to apply, this // method returns [ErrNoNextVersion]. -func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { +func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { res, err := p.down(ctx, true, 0) if err != nil { return nil, err @@ -223,16 +235,19 @@ func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { if len(res) == 0 { return nil, ErrNoNextVersion } - return res, nil + if len(res) > 1 { + return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res)) + } + return res[0], nil } -// DownTo rolls back all migrations down to but not including the specified version. +// DownTo rolls back all migrations down to, but not including, the specified version. // -// For instance, if the current database version is 11, and the requested version is 9, only -// migrations 11 and 10 will be rolled back. +// For instance, if the current database version is 11,10,9... and the requested version is 9, only +// migrations 11, 10 will be rolled back. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { if version < 0 { - return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version) + return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version) } return p.down(ctx, false, version) } diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index 0b7cd7a..5f069d4 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -10,6 +10,9 @@ import ( ) const ( + // DefaultTablename is the default name of the database table used to track history of applied + // migrations. It can be overridden using the [WithTableName] option when creating a new + // provider. DefaultTablename = "goose_db_version" ) @@ -85,7 +88,8 @@ type GoMigration struct { // WithGoMigration registers a Go migration with the given version. // // If WithGoMigration is called multiple times with the same version, an error is returned. Both up -// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set. +// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be +// set. func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { return configFunc(func(c *config) error { if version < 1 { @@ -122,9 +126,10 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { // WithAllowMissing allows the provider to apply missing (out-of-order) migrations. // -// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is -// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied -// migrations will be: 1,6,2,3,5. +// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true, +// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of +// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed +// by new migrations. func WithAllowMissing(b bool) ProviderOption { return configFunc(func(c *config) error { c.allowMissing = b @@ -132,9 +137,9 @@ func WithAllowMissing(b bool) ProviderOption { }) } -// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations -// without tracking the versions in the database schema table. Useful for tests, seeding a database -// or running ad-hoc queries. +// WithNoVersioning disables versioning. Disabling versioning allows applying migrations without +// tracking the versions in the database schema table. Useful for tests, seeding a database or +// running ad-hoc queries. func WithNoVersioning(b bool) ProviderOption { return configFunc(func(c *config) error { c.noVersioning = b diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index ac4ec7e..6cd7a5f 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -41,7 +41,7 @@ func TestProvider(t *testing.T) { // Not parallel because it modifies global state. register := []*provider.Migration{ { - Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: nil, DownFnContext: nil, }, @@ -69,32 +69,46 @@ func TestProvider(t *testing.T) { t.Run("duplicate_up", func(t *testing.T) { err := provider.SetGlobalGoMigrations([]*provider.Migration{ { - Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: func(context.Context, *sql.Tx) error { return nil }, UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, }, }) - check.NoError(t, err) t.Cleanup(provider.ResetGlobalGoMigrations) - db := newDB(t) - _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) check.HasError(t, err) - check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext") + check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") }) t.Run("duplicate_down", func(t *testing.T) { err := provider.SetGlobalGoMigrations([]*provider.Migration{ { - Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + Version: 1, Source: "00001_users_table.go", Registered: true, DownFnContext: func(context.Context, *sql.Tx) error { return nil }, DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, }, }) - check.NoError(t, err) t.Cleanup(provider.ResetGlobalGoMigrations) - db := newDB(t) - _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) check.HasError(t, err) - check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext") + check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") + }) + t.Run("not_registered", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", + }, + }) + t.Cleanup(provider.ResetGlobalGoMigrations) + check.HasError(t, err) + check.Contains(t, err.Error(), "migration must be registered") + }) + t.Run("zero_not_allowed", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 0, + }, + }) + t.Cleanup(provider.ResetGlobalGoMigrations) + check.HasError(t, err) + check.Contains(t, err.Error(), "migration versions must be greater than zero") }) } diff --git a/internal/provider/run.go b/internal/provider/run.go index 55bef9f..6dc6001 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -15,9 +15,153 @@ import ( "go.uber.org/multierr" ) +var ( + errMissingZeroVersion = errors.New("missing zero version migration") +) + +func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + var apply []*migration + if p.cfg.noVersioning { + apply = p.migrations + } else { + // optimize(mf): Listing all migrations from the database isn't great. This is only required to + // support the allow missing (out-of-order) feature. For users that don't use this feature, we + // could just query the database for the current max version and then apply migrations greater + // than that version. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + apply, err = p.resolveUpMigrations(dbMigrations, version) + if err != nil { + return nil, err + } + } + // feat(mf): this is where can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. + // https://github.com/pressly/goose/issues/222 + // + // Careful, we can't use a single transaction for all migrations because some may have to be run + // in their own transaction. + return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne) +} + +func (p *Provider) resolveUpMigrations( + dbVersions []*sqladapter.ListMigrationsResult, + version int64, +) ([]*migration, error) { + var apply []*migration + var dbMaxVersion int64 + // dbAppliedVersions is a map of all applied migrations in the database. + dbAppliedVersions := make(map[int64]bool, len(dbVersions)) + for _, m := range dbVersions { + dbAppliedVersions[m.Version] = true + if m.Version > dbMaxVersion { + dbMaxVersion = m.Version + } + } + missingMigrations := findMissingMigrations(dbVersions, p.migrations) + // feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing + // migrations entirely. At the moment this is not supported, but leaving this comment because + // that's where that logic would be handled. + // + // For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. + // Not sure if this is a common use case, but it's possible. + if len(missingMigrations) > 0 && !p.cfg.allowMissing { + var collected []string + for _, v := range missingMigrations { + collected = append(collected, v.filename) + } + msg := "migration" + if len(collected) > 1 { + msg += "s" + } + return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]", + len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","), + ) + } + for _, v := range missingMigrations { + m, err := p.getMigration(v.versionID) + if err != nil { + return nil, err + } + apply = append(apply, m) + } + // filter all migrations with a version greater than the supplied version (min) and less than or + // equal to the requested version (max). Skip any migrations that have already been applied. + for _, m := range p.migrations { + if dbAppliedVersions[m.Source.Version] { + continue + } + if m.Source.Version > dbMaxVersion && m.Source.Version <= version { + apply = append(apply, m) + } + } + return apply, nil +} + +func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.noVersioning { + downMigrations := p.migrations + if downByOne { + last := p.migrations[len(p.migrations)-1] + downMigrations = []*migration{last} + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if len(dbMigrations) == 0 { + return nil, errMissingZeroVersion + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + var downMigrations []*migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} + // runMigrations runs migrations sequentially in the given direction. // -// If the migrations slice is empty, this function returns nil with no error. +// If the migrations list is empty, return nil without error. func (p *Provider) runMigrations( ctx context.Context, conn *sql.Conn, @@ -28,28 +172,20 @@ func (p *Provider) runMigrations( if len(migrations) == 0 { return nil, nil } - var apply []*migration + apply := migrations if byOne { - apply = []*migration{migrations[0]} - } else { - apply = migrations + apply = migrations[:1] } // Lazily parse SQL migrations (if any) in both directions. We do this before running any // migrations so that we can fail fast if there are any errors and avoid leaving the database in // a partially migrated state. - if err := parseSQL(p.fsys, false, apply); err != nil { return nil, err } - - // TODO(mf): If we decide to add support for advisory locks at the transaction level, this may + // feat(mf): If we decide to add support for advisory locks at the transaction level, this may // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe // to run in a transaction. - // - // - // - // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then // this will deadlock because the Go migration will try to acquire a connection from the pool, @@ -57,31 +193,32 @@ func (p *Provider) runMigrations( // // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of - // an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 { - // for _, m := range apply { - // if m.IsGo() && !m.Go.UseTx { - // return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1") - // } - // } - // } + // an edge case. + if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 { + for _, m := range apply { + switch m.Source.Type { + case TypeGo: + if m.Go != nil && m.useTx(direction.ToBool()) { + return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1") + } + } + } + } - // Run migrations individually, opening a new transaction for each migration if the migration is - // safe to run in a transaction. - - // Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the - // impression that N migrations were applied when in fact some were not 2. Avoid the caller - // having to check for nil results + // Avoid allocating a slice because we may have a partial migration error. + // 1. Avoid giving the impression that N migrations were applied when in fact some were not + // 2. Avoid the caller having to check for nil results var results []*MigrationResult for _, m := range apply { current := &MigrationResult{ Source: m.Source, - Direction: strings.ToLower(direction.String()), - // TODO(mf): empty set here + Direction: direction.String(), + Empty: m.isEmpty(direction.ToBool()), } - start := time.Now() if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { - // TODO(mf): we should also return the pending migrations here. + // TODO(mf): we should also return the pending migrations here, the remaining items in + // the apply slice. current.Error = err current.Duration = time.Since(start) return nil, &PartialError{ @@ -90,16 +227,12 @@ func (p *Provider) runMigrations( Err: err, } } - current.Duration = time.Since(start) results = append(results, current) } return results, nil } -// runIndividually runs an individual migration, opening a new transaction if the migration is safe -// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the -// supplied connection. func (p *Provider) runIndividually( ctx context.Context, conn *sql.Conn, @@ -182,8 +315,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err cleanup = func() error { p.mu.Unlock() // Use a detached context to unlock the session. This is because the context passed to - // SessionLock may have been canceled, and we don't want to cancel the unlock. - // TODO(mf): use [context.WithoutCancel] added in go1.21 + // SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf): + // use [context.WithoutCancel] added in go1.21 detachedCtx := context.Background() return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) } @@ -206,7 +339,7 @@ func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error { for _, m := range migrations { // If the migration is a SQL migration, and it has not been parsed, parse it. if m.Source.Type == TypeSQL && m.SQL == nil { - parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug) if err != nil { return err } @@ -248,11 +381,14 @@ type missingMigration struct { func findMissingMigrations( dbMigrations []*sqladapter.ListMigrationsResult, fsMigrations []*migration, - dbMaxVersion int64, ) []missingMigration { existing := make(map[int64]bool) + var dbMaxVersion int64 for _, m := range dbMigrations { existing[m.Version] = true + if m.Version > dbMaxVersion { + dbMaxVersion = m.Version + } } var missing []missingMigration for _, m := range fsMigrations { @@ -282,10 +418,6 @@ func (p *Provider) getMigration(version int64) (*migration, error) { } func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { - if version < 1 { - return nil, errors.New("version must be greater than zero") - } - m, err := p.getMigration(version) if err != nil { return nil, err @@ -371,5 +503,8 @@ func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { if len(res) == 0 { return 0, nil } + sort.Slice(res, func(i, j int) bool { + return res[i].Version > res[j].Version + }) return res[0].Version, nil } diff --git a/internal/provider/run_down.go b/internal/provider/run_down.go deleted file mode 100644 index 011ba79..0000000 --- a/internal/provider/run_down.go +++ /dev/null @@ -1,53 +0,0 @@ -package provider - -import ( - "context" - - "github.com/pressly/goose/v3/internal/sqlparser" - "go.uber.org/multierr" -) - -func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - if len(p.migrations) == 0 { - return nil, nil - } - - if p.cfg.noVersioning { - var downMigrations []*migration - if downByOne { - downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1]) - } else { - downMigrations = p.migrations - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) - } - - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - if dbMigrations[0].Version == 0 { - return nil, nil - } - - var downMigrations []*migration - for _, dbMigration := range dbMigrations { - if dbMigration.Version <= version { - break - } - m, err := p.getMigration(dbMigration.Version) - if err != nil { - return nil, err - } - downMigrations = append(downMigrations, m) - } - return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) -} diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index 97e86ed..09c71cd 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -53,13 +53,13 @@ func TestProviderRun(t *testing.T) { p, _ := newProviderWithDB(t) _, err := p.UpTo(context.Background(), 0) check.HasError(t, err) - check.Equal(t, err.Error(), "version must be greater than zero") + check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0") _, err = p.DownTo(context.Background(), -1) check.HasError(t, err) - check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1") + check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1") _, err = p.ApplyVersion(context.Background(), 0, true) check.HasError(t, err) - check.Equal(t, err.Error(), "version must be greater than zero") + check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0") }) t.Run("up_and_down_all", func(t *testing.T) { ctx := context.Background() @@ -77,24 +77,24 @@ func TestProviderRun(t *testing.T) { res, err := p.Up(ctx) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") - assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") - assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up") - assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up") - assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up") - assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up") - assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up") + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false) + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false) + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false) + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true) + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) // Test Down res, err = p.DownTo(ctx, 0) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down") - assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down") - assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down") - assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down") - assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down") - assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down") - assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down") + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true) + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false) + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false) + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false) + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false) + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false) }) t.Run("up_and_down_by_one", func(t *testing.T) { ctx := context.Background() @@ -112,8 +112,8 @@ func TestProviderRun(t *testing.T) { break } check.NoError(t, err) - check.Number(t, len(res), 1) - check.Number(t, res[0].Source.Version, int64(counter)) + check.NotNil(t, res) + check.Number(t, res.Source.Version, int64(counter)) } currentVersion, err := p.GetDBVersion(ctx) check.NoError(t, err) @@ -131,8 +131,8 @@ func TestProviderRun(t *testing.T) { break } check.NoError(t, err) - check.Number(t, len(res), 1) - check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1)) + check.NotNil(t, res) + check.Number(t, res.Source.Version, int64(maxVersion-counter+1)) } // Once everything is tested the version should match the highest testdata version currentVersion, err = p.GetDBVersion(ctx) @@ -148,8 +148,8 @@ func TestProviderRun(t *testing.T) { results, err := p.UpTo(ctx, upToVersion) check.NoError(t, err) check.Number(t, len(results), upToVersion) - assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") - assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) // Fetch the goose version from DB currentVersion, err := p.GetDBVersion(ctx) check.NoError(t, err) @@ -237,7 +237,11 @@ func TestProviderRun(t *testing.T) { res, err := p.ApplyVersion(ctx, s.Version, true) check.NoError(t, err) // Round-trip the migration result through the database to ensure it's valid. - assertResult(t, res, s, "up") + var empty bool + if s.Version == 6 || s.Version == 7 { + empty = true + } + assertResult(t, res, s, "up", empty) } // Apply all migrations in the down direction. for i := len(sources) - 1; i >= 0; i-- { @@ -245,7 +249,11 @@ func TestProviderRun(t *testing.T) { res, err := p.ApplyVersion(ctx, s.Version, false) check.NoError(t, err) // Round-trip the migration result through the database to ensure it's valid. - assertResult(t, res, s, "down") + var empty bool + if s.Version == 6 || s.Version == 7 { + empty = true + } + assertResult(t, res, s, "down", empty) } // Try apply version 1 multiple times _, err := p.ApplyVersion(ctx, 1, true) @@ -324,7 +332,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") // Check Results field check.Number(t, len(expected.Applied), 1) - assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) // Check Failed field check.Bool(t, expected.Failed != nil, true) assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) @@ -368,11 +376,11 @@ func TestConcurrentProvider(t *testing.T) { t.Error(err) return } - if len(res) != 1 { - t.Errorf("expected 1 result, got %d", len(res)) + if res == nil { + t.Errorf("expected non-nil result, got nil") return } - ch <- res[0].Source.Version + ch <- res.Source.Version }() } go func() { @@ -623,13 +631,13 @@ func TestAllowMissing(t *testing.T) { // 4 upResult, err := p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(upResult), 1) - check.Number(t, upResult[0].Source.Version, 4) + check.NotNil(t, upResult) + check.Number(t, upResult.Source.Version, 4) // 6 upResult, err = p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(upResult), 1) - check.Number(t, upResult[0].Source.Version, 6) + check.NotNil(t, upResult) + check.Number(t, upResult.Source.Version, 6) count, err := getGooseVersionCount(db, provider.DefaultTablename) check.NoError(t, err) @@ -645,21 +653,29 @@ func TestAllowMissing(t *testing.T) { // So migrating down should be the reverse of the applied order: // 6,4,5,3,2,1 - expected := []int64{6, 4, 5, 3, 2, 1} - for i, v := range expected { - // TODO(mf): this is returning it by the order it was applied. - current, err := p.GetDBVersion(ctx) + testDownAndVersion := func(wantDBVersion, wantResultVersion int64) { + currentVersion, err := p.GetDBVersion(ctx) check.NoError(t, err) - check.Number(t, current, v) - downResult, err := p.Down(ctx) - if i == len(expected)-1 { - check.HasError(t, provider.ErrVersionNotFound) - } else { - check.NoError(t, err) - check.Number(t, len(downResult), 1) - check.Number(t, downResult[0].Source.Version, v) - } + check.Number(t, currentVersion, wantDBVersion) + downRes, err := p.Down(ctx) + check.NoError(t, err) + check.NotNil(t, downRes) + check.Number(t, downRes.Source.Version, wantResultVersion) } + + // This behaviour may need to change, see the following issues for more details: + // - https://github.com/pressly/goose/issues/523 + // - https://github.com/pressly/goose/issues/402 + + testDownAndVersion(6, 6) + testDownAndVersion(5, 4) // Ensure the max db version is 5 before down. + testDownAndVersion(5, 5) + testDownAndVersion(3, 3) + testDownAndVersion(2, 2) + testDownAndVersion(1, 1) + _, err = p.Down(ctx) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true) }) } @@ -674,7 +690,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { } func TestGoOnly(t *testing.T) { - // Not parallel because it modifies global state. + // Not parallel because each subtest modifies global state. countUser := func(db *sql.DB) int { q := `SELECT count(*)FROM users` @@ -688,7 +704,7 @@ func TestGoOnly(t *testing.T) { ctx := context.Background() register := []*provider.Migration{ { - Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), DownFnContext: newTxFn("DROP TABLE users"), }, @@ -713,27 +729,23 @@ func TestGoOnly(t *testing.T) { // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) @@ -741,7 +753,7 @@ func TestGoOnly(t *testing.T) { ctx := context.Background() register := []*provider.Migration{ { - Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + Version: 1, Source: "00001_users_table.go", Registered: true, UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), DownFnNoTxContext: newDBFn("DROP TABLE users"), }, @@ -766,27 +778,23 @@ func TestGoOnly(t *testing.T) { // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - check.Number(t, len(res), 1) - assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) @@ -880,7 +888,7 @@ func TestLockModeAdvisorySession(t *testing.T) { ) g.Go(func() error { for { - results, err := provider1.UpByOne(context.Background()) + result, err := provider1.UpByOne(context.Background()) if err != nil { if errors.Is(err, provider.ErrNoNextVersion) { return nil @@ -888,17 +896,15 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - if len(results) != 1 { - return fmt.Errorf("expected 1 result, got %d", len(results)) - } + check.NotNil(t, result) mu.Lock() - applied = append(applied, results[0].Source.Version) + applied = append(applied, result.Source.Version) mu.Unlock() } }) g.Go(func() error { for { - results, err := provider2.UpByOne(context.Background()) + result, err := provider2.UpByOne(context.Background()) if err != nil { if errors.Is(err, provider.ErrNoNextVersion) { return nil @@ -906,11 +912,9 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - if len(results) != 1 { - return fmt.Errorf("expected 1 result, got %d", len(results)) - } + check.NotNil(t, result) mu.Lock() - applied = append(applied, results[0].Source.Version) + applied = append(applied, result.Source.Version) mu.Unlock() } }) @@ -986,37 +990,33 @@ func TestLockModeAdvisorySession(t *testing.T) { ) g.Go(func() error { for { - results, err := provider1.Down(context.Background()) + result, err := provider1.Down(context.Background()) if err != nil { if errors.Is(err, provider.ErrNoNextVersion) { return nil } return err } - if len(results) != 1 { - return fmt.Errorf("expected 1 result, got %d", len(results)) - } check.NoError(t, err) + check.NotNil(t, result) mu.Lock() - applied = append(applied, results[0].Source.Version) + applied = append(applied, result.Source.Version) mu.Unlock() } }) g.Go(func() error { for { - results, err := provider2.Down(context.Background()) + result, err := provider2.Down(context.Background()) if err != nil { if errors.Is(err, provider.ErrNoNextVersion) { return nil } return err } - if len(results) != 1 { - return fmt.Errorf("expected 1 result, got %d", len(results)) - } check.NoError(t, err) + check.NotNil(t, result) mu.Lock() - applied = append(applied, results[0].Source.Version) + applied = append(applied, result.Source.Version) mu.Unlock() } }) @@ -1124,11 +1124,12 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) } -func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) { +func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) { t.Helper() + check.NotNil(t, got) check.Equal(t, got.Source, source) check.Equal(t, got.Direction, direction) - check.Equal(t, got.Empty, false) + check.Equal(t, got.Empty, isEmpty) check.Bool(t, got.Error == nil, true) check.Bool(t, got.Duration > 0, true) } @@ -1136,7 +1137,7 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { t.Helper() check.Equal(t, got.Type, typ) - check.Equal(t, got.Fullpath, name) + check.Equal(t, got.Path, name) check.Equal(t, got.Version, version) switch got.Type { case provider.TypeGo: diff --git a/internal/provider/run_up.go b/internal/provider/run_up.go deleted file mode 100644 index 7ee9c6c..0000000 --- a/internal/provider/run_up.go +++ /dev/null @@ -1,96 +0,0 @@ -package provider - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/pressly/goose/v3/internal/sqlparser" - "go.uber.org/multierr" -) - -func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { - if version < 1 { - return nil, errors.New("version must be greater than zero") - } - - conn, cleanup, err := p.initialize(ctx) - if err != nil { - return nil, err - } - defer func() { - retErr = multierr.Append(retErr, cleanup()) - }() - - if len(p.migrations) == 0 { - return nil, nil - } - if p.cfg.noVersioning { - // Short circuit if versioning is disabled and apply all migrations. - return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne) - } - - // optimize(mf): Listing all migrations from the database isn't great. This is only required to - // support the out-of-order (allow missing) feature. For users who don't use this feature, we - // could just query the database for the current version and then apply migrations that are - // greater than that version. - dbMigrations, err := p.store.ListMigrations(ctx, conn) - if err != nil { - return nil, err - } - dbMaxVersion := dbMigrations[0].Version - // lookupAppliedInDB is a map of all applied migrations in the database. - lookupAppliedInDB := make(map[int64]bool) - for _, m := range dbMigrations { - lookupAppliedInDB[m.Version] = true - } - - missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion) - - // feature(mf): It is very possible someone may want to apply ONLY new migrations and skip - // missing migrations entirely. At the moment this is not supported, but leaving this comment - // because that's where that logic will be handled. - if len(missingMigrations) > 0 && !p.cfg.allowMissing { - var collected []string - for _, v := range missingMigrations { - collected = append(collected, v.filename) - } - msg := "migration" - if len(collected) > 1 { - msg += "s" - } - return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]", - len(missingMigrations), msg, strings.Join(collected, ",")) - } - - var migrationsToApply []*migration - if p.cfg.allowMissing { - for _, v := range missingMigrations { - m, err := p.getMigration(v.versionID) - if err != nil { - return nil, err - } - migrationsToApply = append(migrationsToApply, m) - } - } - // filter all migrations with a version greater than the supplied version (min) and less than or - // equal to the requested version (max). - for _, m := range p.migrations { - if lookupAppliedInDB[m.Source.Version] { - continue - } - if m.Source.Version > dbMaxVersion && m.Source.Version <= version { - migrationsToApply = append(migrationsToApply, m) - } - } - - // feat(mf): this is where can (optionally) group multiple migrations to be run in a single - // transaction. The default is to apply each migration sequentially on its own. - // https://github.com/pressly/goose/issues/222 - // - // Note, we can't use a single transaction for all migrations because some may have to be run in - // their own transaction. - - return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne) -} diff --git a/internal/provider/types.go b/internal/provider/types.go index 21bb18b..979eac1 100644 --- a/internal/provider/types.go +++ b/internal/provider/types.go @@ -41,36 +41,23 @@ func (t MigrationType) String() string { // Source represents a single migration source. // -// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if -// the migration has a corresponding file on disk. It will be empty if the migration was registered -// manually. +// The Path field may be empty if the migration was registered manually. This is typically the case +// for Go migrations registered using the [WithGoMigration] option. type Source struct { - // Type is the type of migration. - Type MigrationType - // Full path to the migration file. - // - // Example: /path/to/migrations/001_create_users_table.sql - Fullpath string - // Version is the version of the migration. + Type MigrationType + Path string Version int64 } // MigrationResult is the result of a single migration operation. -// -// Note, the caller is responsible for checking the Error field for any errors that occurred while -// running the migration. If the Error field is not nil, the migration failed. type MigrationResult struct { Source Source Duration time.Duration Direction string - // Empty is true if the file was valid, but no statements to apply. These are still versioned - // migrations, but typically have no effect on the database. - // - // For SQL migrations, this means there was a valid .sql file but contained no statements. For - // Go migrations, this means the function was nil. + // Empty indicates no action was taken during the migration, but it was still versioned. For + // SQL, it means no statements; for Go, it's a nil function. Empty bool - - // Error is any error that occurred while running the migration. + // Error is only set if the migration failed. Error error } @@ -78,22 +65,20 @@ type MigrationResult struct { type State string const ( - // StatePending represents a migration that is on the filesystem, but not in the database. + // StatePending is a migration that exists on the filesystem, but not in the database. StatePending State = "pending" - // StateApplied represents a migration that is in BOTH the database and on the filesystem. + // StateApplied is a migration that has been applied to the database and exists on the + // filesystem. StateApplied State = "applied" - // StateUntracked represents a migration that is in the database, but not on the filesystem. - // StateUntracked State = "untracked" + // TODO(mf): we could also add a third state for untracked migrations. This would be useful for + // migrations that were manually applied to the database, but not versioned. Or the Source was + // deleted, but the migration still exists in the database. StateUntracked State = "untracked" ) // MigrationStatus represents the status of a single migration. type MigrationStatus struct { - // State is the state of the migration. - State State - // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or - // [StateUntracked]. + Source Source + State State AppliedAt time.Time - // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. - Source Source } diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go index 1d01895..69d3d31 100644 --- a/internal/sqladapter/store_test.go +++ b/internal/sqladapter/store_test.go @@ -61,6 +61,25 @@ func TestStore(t *testing.T) { check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") }) }) + t.Run("ListMigrations", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + store, err := sqladapter.NewStore("sqlite3", "foo") + check.NoError(t, err) + err = store.CreateVersionTable(context.Background(), db) + check.NoError(t, err) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1)) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3)) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2)) + res, err := store.ListMigrations(context.Background(), db) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check versions are in descending order: [2, 3, 1] + check.Number(t, res[0].Version, 2) + check.Number(t, res[1].Version, 3) + check.Number(t, res[2].Version, 1) + }) } // testStore tests various store operations.