feat(experimental): Shuffle packages and tidy up (#619)

This commit is contained in:
Michael Fridman 2023-10-25 08:56:17 -04:00 committed by GitHub
parent 3b801a60c7
commit a9da7504fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 518 additions and 378 deletions

View File

@ -8,6 +8,13 @@ import (
"testing" "testing"
) )
func NotNil(t *testing.T, v any) {
t.Helper()
if v == nil {
t.Fatal("unexpected nil value")
}
}
func NoError(t *testing.T, err error) { func NoError(t *testing.T, err error) {
t.Helper() t.Helper()
if err != nil { if err != nil {

View File

@ -14,7 +14,7 @@ import (
func NewSource(t MigrationType, fullpath string, version int64) Source { func NewSource(t MigrationType, fullpath string, version int64) Source {
return Source{ return Source{
Type: t, Type: t,
Fullpath: fullpath, Path: fullpath,
Version: version, Version: version,
} }
} }
@ -133,7 +133,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
var unregistered []string var unregistered []string
for _, s := range sources.goSources { for _, s := range sources.goSources {
if _, ok := registerd[s.Version]; !ok { if _, ok := registerd[s.Version]; !ok {
unregistered = append(unregistered, s.Fullpath) unregistered = append(unregistered, s.Path)
} }
} }
if len(unregistered) > 0 { if len(unregistered) > 0 {
@ -149,7 +149,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
fullpath := r.fullpath fullpath := r.fullpath
if fullpath == "" { if fullpath == "" {
if s := sources.lookup(TypeGo, version); s != nil { if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath fullpath = s.Path
} }
} }
// Ensure there are no duplicate versions. // Ensure there are no duplicate versions.
@ -160,7 +160,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
} }
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version, version,
existing.Source.Fullpath, existing.Source.Path,
fullpath, fullpath,
) )
} }

View File

@ -6,6 +6,7 @@ import (
"testing/fstest" "testing/fstest"
"github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
) )
func TestCollectFileSources(t *testing.T) { func TestCollectFileSources(t *testing.T) {
@ -120,13 +121,13 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.sqlSources), 2) check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1) check.Number(t, len(sources.goSources), 1)
// 1 // 1
check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql") check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql")
check.Equal(t, sources.sqlSources[0].Version, int64(1)) check.Equal(t, sources.sqlSources[0].Version, int64(1))
// 2 // 2
check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql") check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql")
check.Equal(t, sources.sqlSources[1].Version, int64(5)) check.Equal(t, sources.sqlSources[1].Version, int64(5))
// 3 // 3
check.Equal(t, sources.goSources[0].Fullpath, "4_something.go") check.Equal(t, sources.goSources[0].Path, "4_something.go")
check.Equal(t, sources.goSources[0].Version, int64(4)) check.Equal(t, sources.goSources[0].Version, int64(4))
}) })
t.Run("duplicate_versions", func(t *testing.T) { t.Run("duplicate_versions", func(t *testing.T) {
@ -286,6 +287,65 @@ func TestMerge(t *testing.T) {
}) })
} }
func TestFindMissingMigrations(t *testing.T) {
t.Parallel()
t.Run("db_has_max_version", func(t *testing.T) {
// Test case: database has migrations 1, 3, 4, 5, 7
// Missing migrations: 2, 6
// Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8
dbMigrations := []*sqladapter.ListMigrationsResult{
{Version: 1},
{Version: 3},
{Version: 4},
{Version: 5},
{Version: 7}, // <-- database max version_id
}
fsMigrations := []*migration{
newMigration(1),
newMigration(2), // missing migration
newMigration(3),
newMigration(4),
newMigration(5),
newMigration(6), // missing migration
newMigration(7), // ----- database max version_id -----
newMigration(8), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 2)
check.Number(t, got[1].versionID, 6)
// Sanity check.
check.Number(t, len(findMissingMigrations(nil, nil)), 0)
check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0)
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
})
t.Run("fs_has_max_version", func(t *testing.T) {
dbMigrations := []*sqladapter.ListMigrationsResult{
{Version: 1},
{Version: 5},
{Version: 2},
}
fsMigrations := []*migration{
newMigration(3), // new migration
newMigration(4), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 3)
check.Number(t, got[1].versionID, 4)
})
}
func newMigration(version int64) *migration {
return &migration{
Source: Source{
Version: version,
},
}
}
func assertMigration(t *testing.T, got *migration, want Source) { func assertMigration(t *testing.T, got *migration, want Source) {
t.Helper() t.Helper()
check.Equal(t, got.Source, want) check.Equal(t, got.Source, want)

View File

@ -22,18 +22,19 @@ var (
// PartialError is returned when a migration fails, but some migrations already got applied. // PartialError is returned when a migration fails, but some migrations already got applied.
type PartialError struct { type PartialError struct {
// Applied are migrations that were applied successfully before the error occurred. // Applied are migrations that were applied successfully before the error occurred. May be
// empty.
Applied []*MigrationResult Applied []*MigrationResult
// Failed contains the result of the migration that failed. // Failed contains the result of the migration that failed. Cannot be nil.
Failed *MigrationResult Failed *MigrationResult
// Err is the error that occurred while running the migration. // Err is the error that occurred while running the migration and caused the failure.
Err error Err error
} }
func (e *PartialError) Error() string { func (e *PartialError) Error() string {
filename := "(file unknown)" filename := "(file unknown)"
if e.Failed != nil && e.Failed.Source.Fullpath != "" { if e.Failed != nil && e.Failed.Source.Path != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath)) filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
} }
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
} }

View File

@ -29,7 +29,7 @@ func (m *migration) useTx(direction bool) bool {
case TypeSQL: case TypeSQL:
return m.SQL.UseTx return m.SQL.UseTx
case TypeGo: case TypeGo:
if m.Go == nil { if m.Go == nil || m.Go.isEmpty(direction) {
return false return false
} }
if direction { if direction {
@ -41,8 +41,18 @@ func (m *migration) useTx(direction bool) bool {
return false return false
} }
func (m *migration) isEmpty(direction bool) bool {
switch m.Source.Type {
case TypeSQL:
return m.SQL == nil || m.SQL.IsEmpty(direction)
case TypeGo:
return m.Go == nil || m.Go.isEmpty(direction)
}
return true
}
func (m *migration) filename() string { func (m *migration) filename() string {
return filepath.Base(m.Source.Fullpath) return filepath.Base(m.Source.Path)
} }
// run runs the migration inside of a transaction. // run runs the migration inside of a transaction.
@ -57,7 +67,7 @@ func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
return m.Go.run(ctx, tx, direction) return m.Go.run(ctx, tx, direction)
} }
// This should never happen. // This should never happen.
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
} }
// runNoTx runs the migration without a transaction. // runNoTx runs the migration without a transaction.
@ -72,7 +82,7 @@ func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) err
return m.Go.runNoTx(ctx, db, direction) return m.Go.runNoTx(ctx, db, direction)
} }
// This should never happen. // This should never happen.
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
} }
// runConn runs the migration without a transaction using the provided connection. // runConn runs the migration without a transaction using the provided connection.
@ -87,7 +97,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool)
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
} }
// This should never happen. // This should never happen.
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
} }
type goMigration struct { type goMigration struct {
@ -95,6 +105,16 @@ type goMigration struct {
up, down *GoMigration up, down *GoMigration
} }
func (g *goMigration) isEmpty(direction bool) bool {
if g.up == nil && g.down == nil {
panic("go migration has no up or down")
}
if direction {
return g.up == nil
}
return g.down == nil
}
func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { func newGoMigration(fullpath string, up, down *GoMigration) *goMigration {
return &goMigration{ return &goMigration{
fullpath: fullpath, fullpath: fullpath,

View File

@ -11,21 +11,45 @@ type Migration struct {
Version int64 Version int64
Source string // path to .sql script or go file Source string // path to .sql script or go file
Registered bool Registered bool
UseTx bool UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
UpFnContext func(context.Context, *sql.Tx) error UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
DownFnContext func(context.Context, *sql.Tx) error
UpFnNoTxContext func(context.Context, *sql.DB) error
DownFnNoTxContext func(context.Context, *sql.DB) error
} }
var registeredGoMigrations = make(map[int64]*Migration) var registeredGoMigrations = make(map[int64]*Migration)
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
// the migrations are nil or if a migration with the same version has already been registered.
//
// Not safe for concurrent use.
func SetGlobalGoMigrations(migrations []*Migration) error { func SetGlobalGoMigrations(migrations []*Migration) error {
for _, m := range migrations { for _, m := range migrations {
if m == nil { if m == nil {
return errors.New("cannot register nil go migration") return errors.New("cannot register nil go migration")
} }
if m.Version < 1 {
return errors.New("migration versions must be greater than zero")
}
if !m.Registered {
return errors.New("migration must be registered")
}
if m.Source != "" {
// If the source is set, expect it to be a file path with a numeric component that
// matches the version.
version, err := NumericComponent(m.Source)
if err != nil {
return err
}
if version != m.Version {
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
}
}
// It's valid for all of these to be nil.
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
if _, ok := registeredGoMigrations[m.Version]; ok { if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version) return fmt.Errorf("go migration with version %d already registered", m.Version)
} }
@ -34,6 +58,9 @@ func SetGlobalGoMigrations(migrations []*Migration) error {
return nil return nil
} }
// ResetGlobalGoMigrations resets the global go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalGoMigrations() { func ResetGlobalGoMigrations() {
registeredGoMigrations = make(map[int64]*Migration) registeredGoMigrations = make(map[int64]*Migration)
} }

View File

@ -19,8 +19,9 @@ import (
// github.com/lib/pq or github.com/jackc/pgx. // github.com/lib/pq or github.com/jackc/pgx.
// //
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to // fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is // use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub. // However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
// migrations using [fs.Sub].
// //
// See [ProviderOption] for more information on configuring the provider. // See [ProviderOption] for more information on configuring the provider.
// //
@ -39,6 +40,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
} }
cfg := config{ cfg := config{
registered: make(map[int64]*goMigration), registered: make(map[int64]*goMigration),
excludes: make(map[string]bool),
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt.apply(&cfg); err != nil { if err := opt.apply(&cfg); err != nil {
@ -137,6 +139,8 @@ type Provider struct {
fsys fs.FS fsys fs.FS
cfg config cfg config
store sqladapter.Store store sqladapter.Store
// migrations are ordered by version in ascending order.
migrations []*migration migrations []*migration
} }
@ -149,8 +153,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
// GetDBVersion returns the max version from the database, regardless of the applied order. For // GetDBVersion returns the max version from the database, regardless of the applied order. For
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
// applied, it returns 0. // applied, it returns 0.
//
// TODO(mf): this is not true?
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return p.getDBVersion(ctx) return p.getDBVersion(ctx)
} }
@ -175,25 +177,28 @@ func (p *Provider) Close() error {
return p.db.Close() return p.db.Close()
} }
// ApplyVersion applies exactly one migration at the specified version. If there is no source for // ApplyVersion applies exactly one migration by version. If there is no source for the specified
// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been // version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
// applied already, this method returns [ErrAlreadyApplied]. // this method returns [ErrAlreadyApplied].
// //
// When direction is true, the up migration is executed, and when direction is false, the down // When direction is true, the up migration is executed, and when direction is false, the down
// migration is executed. // migration is executed.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.apply(ctx, version, direction) return p.apply(ctx, version, direction)
} }
// Up applies all pending migrations. If there are no new migrations to apply, this method returns // Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
// empty list and nil error. // returns empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
return p.up(ctx, false, math.MaxInt64) return p.up(ctx, false, math.MaxInt64)
} }
// UpByOne applies the next available migration. If there are no migrations to apply, this method // UpByOne applies the next available migration. If there are no migrations to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. // returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
res, err := p.up(ctx, true, math.MaxInt64) res, err := p.up(ctx, true, math.MaxInt64)
if err != nil { if err != nil {
return nil, err return nil, err
@ -201,21 +206,28 @@ func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) {
if len(res) == 0 { if len(res) == 0 {
return nil, ErrNoNextVersion return nil, ErrNoNextVersion
} }
return res, nil // This should never happen. We should always have exactly one result and test for this.
if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
}
return res[0], nil
} }
// UpTo applies all available migrations up to and including the specified version. If there are no // UpTo applies all available migrations up to, and including, the specified version. If there are
// migrations to apply, this method returns empty list and nil error. // no migrations to apply, this method returns empty list and nil error.
// //
// For instance, if there are three new migrations (9,10,11) and the current database version is 8 // For instance, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9 and 10 will be applied. // with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.up(ctx, false, version) return p.up(ctx, false, version)
} }
// Down rolls back the most recently applied migration. If there are no migrations to apply, this // Down rolls back the most recently applied migration. If there are no migrations to apply, this
// method returns [ErrNoNextVersion]. // method returns [ErrNoNextVersion].
func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
res, err := p.down(ctx, true, 0) res, err := p.down(ctx, true, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@ -223,16 +235,19 @@ func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) {
if len(res) == 0 { if len(res) == 0 {
return nil, ErrNoNextVersion return nil, ErrNoNextVersion
} }
return res, nil if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
}
return res[0], nil
} }
// DownTo rolls back all migrations down to but not including the specified version. // DownTo rolls back all migrations down to, but not including, the specified version.
// //
// For instance, if the current database version is 11, and the requested version is 9, only // For instance, if the current database version is 11,10,9... and the requested version is 9, only
// migrations 11 and 10 will be rolled back. // migrations 11, 10 will be rolled back.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 0 { if version < 0 {
return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version) return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
} }
return p.down(ctx, false, version) return p.down(ctx, false, version)
} }

View File

@ -10,6 +10,9 @@ import (
) )
const ( const (
// DefaultTablename is the default name of the database table used to track history of applied
// migrations. It can be overridden using the [WithTableName] option when creating a new
// provider.
DefaultTablename = "goose_db_version" DefaultTablename = "goose_db_version"
) )
@ -85,7 +88,8 @@ type GoMigration struct {
// WithGoMigration registers a Go migration with the given version. // WithGoMigration registers a Go migration with the given version.
// //
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up // If WithGoMigration is called multiple times with the same version, an error is returned. Both up
// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set. // and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
// set.
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
return configFunc(func(c *config) error { return configFunc(func(c *config) error {
if version < 1 { if version < 1 {
@ -122,9 +126,10 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
// WithAllowMissing allows the provider to apply missing (out-of-order) migrations. // WithAllowMissing allows the provider to apply missing (out-of-order) migrations.
// //
// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is // Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied // then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
// migrations will be: 1,6,2,3,5. // applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
// by new migrations.
func WithAllowMissing(b bool) ProviderOption { func WithAllowMissing(b bool) ProviderOption {
return configFunc(func(c *config) error { return configFunc(func(c *config) error {
c.allowMissing = b c.allowMissing = b
@ -132,9 +137,9 @@ func WithAllowMissing(b bool) ProviderOption {
}) })
} }
// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations // WithNoVersioning disables versioning. Disabling versioning allows applying migrations without
// without tracking the versions in the database schema table. Useful for tests, seeding a database // tracking the versions in the database schema table. Useful for tests, seeding a database or
// or running ad-hoc queries. // running ad-hoc queries.
func WithNoVersioning(b bool) ProviderOption { func WithNoVersioning(b bool) ProviderOption {
return configFunc(func(c *config) error { return configFunc(func(c *config) error {
c.noVersioning = b c.noVersioning = b

View File

@ -41,7 +41,7 @@ func TestProvider(t *testing.T) {
// Not parallel because it modifies global state. // Not parallel because it modifies global state.
register := []*provider.Migration{ register := []*provider.Migration{
{ {
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: nil, UpFnContext: nil,
DownFnContext: nil, DownFnContext: nil,
}, },
@ -69,32 +69,46 @@ func TestProvider(t *testing.T) {
t.Run("duplicate_up", func(t *testing.T) { t.Run("duplicate_up", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{ err := provider.SetGlobalGoMigrations([]*provider.Migration{
{ {
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: func(context.Context, *sql.Tx) error { return nil }, UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
}, },
}) })
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations) t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
check.HasError(t, err) check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext") check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
}) })
t.Run("duplicate_down", func(t *testing.T) { t.Run("duplicate_down", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{ err := provider.SetGlobalGoMigrations([]*provider.Migration{
{ {
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, Version: 1, Source: "00001_users_table.go", Registered: true,
DownFnContext: func(context.Context, *sql.Tx) error { return nil }, DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
}, },
}) })
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations) t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
check.HasError(t, err) check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext") check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
})
t.Run("not_registered", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 1, Source: "00001_users_table.go",
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration must be registered")
})
t.Run("zero_not_allowed", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 0,
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration versions must be greater than zero")
}) })
} }

View File

@ -15,9 +15,153 @@ import (
"go.uber.org/multierr" "go.uber.org/multierr"
) )
var (
errMissingZeroVersion = errors.New("missing zero version migration")
)
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
var apply []*migration
if p.cfg.noVersioning {
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the allow missing (out-of-order) feature. For users that don't use this feature, we
// could just query the database for the current max version and then apply migrations greater
// than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
apply, err = p.resolveUpMigrations(dbMigrations, version)
if err != nil {
return nil, err
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Careful, we can't use a single transaction for all migrations because some may have to be run
// in their own transaction.
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
}
func (p *Provider) resolveUpMigrations(
dbVersions []*sqladapter.ListMigrationsResult,
version int64,
) ([]*migration, error) {
var apply []*migration
var dbMaxVersion int64
// dbAppliedVersions is a map of all applied migrations in the database.
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
for _, m := range dbVersions {
dbAppliedVersions[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
missingMigrations := findMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3.
// Not sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max). Skip any migrations that have already been applied.
for _, m := range p.migrations {
if dbAppliedVersions[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
apply = append(apply, m)
}
}
return apply, nil
}
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
downMigrations := p.migrations
if downByOne {
last := p.migrations[len(p.migrations)-1]
downMigrations = []*migration{last}
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
// runMigrations runs migrations sequentially in the given direction. // runMigrations runs migrations sequentially in the given direction.
// //
// If the migrations slice is empty, this function returns nil with no error. // If the migrations list is empty, return nil without error.
func (p *Provider) runMigrations( func (p *Provider) runMigrations(
ctx context.Context, ctx context.Context,
conn *sql.Conn, conn *sql.Conn,
@ -28,28 +172,20 @@ func (p *Provider) runMigrations(
if len(migrations) == 0 { if len(migrations) == 0 {
return nil, nil return nil, nil
} }
var apply []*migration apply := migrations
if byOne { if byOne {
apply = []*migration{migrations[0]} apply = migrations[:1]
} else {
apply = migrations
} }
// Lazily parse SQL migrations (if any) in both directions. We do this before running any // Lazily parse SQL migrations (if any) in both directions. We do this before running any
// migrations so that we can fail fast if there are any errors and avoid leaving the database in // migrations so that we can fail fast if there are any errors and avoid leaving the database in
// a partially migrated state. // a partially migrated state.
if err := parseSQL(p.fsys, false, apply); err != nil { if err := parseSQL(p.fsys, false, apply); err != nil {
return nil, err return nil, err
} }
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
// TODO(mf): If we decide to add support for advisory locks at the transaction level, this may
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
// to run in a transaction. // to run in a transaction.
//
//
//
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
// this will deadlock because the Go migration will try to acquire a connection from the pool, // this will deadlock because the Go migration will try to acquire a connection from the pool,
@ -57,31 +193,32 @@ func (p *Provider) runMigrations(
// //
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
// an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 { // an edge case.
// for _, m := range apply { if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
// if m.IsGo() && !m.Go.UseTx { for _, m := range apply {
// return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1") switch m.Source.Type {
// } case TypeGo:
// } if m.Go != nil && m.useTx(direction.ToBool()) {
// } return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
}
}
}
}
// Run migrations individually, opening a new transaction for each migration if the migration is // Avoid allocating a slice because we may have a partial migration error.
// safe to run in a transaction. // 1. Avoid giving the impression that N migrations were applied when in fact some were not
// 2. Avoid the caller having to check for nil results
// Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the
// impression that N migrations were applied when in fact some were not 2. Avoid the caller
// having to check for nil results
var results []*MigrationResult var results []*MigrationResult
for _, m := range apply { for _, m := range apply {
current := &MigrationResult{ current := &MigrationResult{
Source: m.Source, Source: m.Source,
Direction: strings.ToLower(direction.String()), Direction: direction.String(),
// TODO(mf): empty set here Empty: m.isEmpty(direction.ToBool()),
} }
start := time.Now() start := time.Now()
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
// TODO(mf): we should also return the pending migrations here. // TODO(mf): we should also return the pending migrations here, the remaining items in
// the apply slice.
current.Error = err current.Error = err
current.Duration = time.Since(start) current.Duration = time.Since(start)
return nil, &PartialError{ return nil, &PartialError{
@ -90,16 +227,12 @@ func (p *Provider) runMigrations(
Err: err, Err: err,
} }
} }
current.Duration = time.Since(start) current.Duration = time.Since(start)
results = append(results, current) results = append(results, current)
} }
return results, nil return results, nil
} }
// runIndividually runs an individual migration, opening a new transaction if the migration is safe
// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the
// supplied connection.
func (p *Provider) runIndividually( func (p *Provider) runIndividually(
ctx context.Context, ctx context.Context,
conn *sql.Conn, conn *sql.Conn,
@ -182,8 +315,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
cleanup = func() error { cleanup = func() error {
p.mu.Unlock() p.mu.Unlock()
// Use a detached context to unlock the session. This is because the context passed to // Use a detached context to unlock the session. This is because the context passed to
// SessionLock may have been canceled, and we don't want to cancel the unlock. // SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
// TODO(mf): use [context.WithoutCancel] added in go1.21 // use [context.WithoutCancel] added in go1.21
detachedCtx := context.Background() detachedCtx := context.Background()
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
} }
@ -206,7 +339,7 @@ func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
for _, m := range migrations { for _, m := range migrations {
// If the migration is a SQL migration, and it has not been parsed, parse it. // If the migration is a SQL migration, and it has not been parsed, parse it.
if m.Source.Type == TypeSQL && m.SQL == nil { if m.Source.Type == TypeSQL && m.SQL == nil {
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
if err != nil { if err != nil {
return err return err
} }
@ -248,11 +381,14 @@ type missingMigration struct {
func findMissingMigrations( func findMissingMigrations(
dbMigrations []*sqladapter.ListMigrationsResult, dbMigrations []*sqladapter.ListMigrationsResult,
fsMigrations []*migration, fsMigrations []*migration,
dbMaxVersion int64,
) []missingMigration { ) []missingMigration {
existing := make(map[int64]bool) existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations { for _, m := range dbMigrations {
existing[m.Version] = true existing[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
} }
var missing []missingMigration var missing []missingMigration
for _, m := range fsMigrations { for _, m := range fsMigrations {
@ -282,10 +418,6 @@ func (p *Provider) getMigration(version int64) (*migration, error) {
} }
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
m, err := p.getMigration(version) m, err := p.getMigration(version)
if err != nil { if err != nil {
return nil, err return nil, err
@ -371,5 +503,8 @@ func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
if len(res) == 0 { if len(res) == 0 {
return 0, nil return 0, nil
} }
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
return res[0].Version, nil return res[0].Version, nil
} }

View File

@ -1,53 +0,0 @@
package provider
import (
"context"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
var downMigrations []*migration
if downByOne {
downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1])
} else {
downMigrations = p.migrations
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}

View File

@ -53,13 +53,13 @@ func TestProviderRun(t *testing.T) {
p, _ := newProviderWithDB(t) p, _ := newProviderWithDB(t)
_, err := p.UpTo(context.Background(), 0) _, err := p.UpTo(context.Background(), 0)
check.HasError(t, err) check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero") check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
_, err = p.DownTo(context.Background(), -1) _, err = p.DownTo(context.Background(), -1)
check.HasError(t, err) check.HasError(t, err)
check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1") check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
_, err = p.ApplyVersion(context.Background(), 0, true) _, err = p.ApplyVersion(context.Background(), 0, true)
check.HasError(t, err) check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero") check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
}) })
t.Run("up_and_down_all", func(t *testing.T) { t.Run("up_and_down_all", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@ -77,24 +77,24 @@ func TestProviderRun(t *testing.T) {
res, err := p.Up(ctx) res, err := p.Up(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), numCount) check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up") assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up") assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up") assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up") assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up") assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
// Test Down // Test Down
res, err = p.DownTo(ctx, 0) res, err = p.DownTo(ctx, 0)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), numCount) check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down") assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down") assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down") assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down") assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down") assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down") assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down") assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
}) })
t.Run("up_and_down_by_one", func(t *testing.T) { t.Run("up_and_down_by_one", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@ -112,8 +112,8 @@ func TestProviderRun(t *testing.T) {
break break
} }
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) check.NotNil(t, res)
check.Number(t, res[0].Source.Version, int64(counter)) check.Number(t, res.Source.Version, int64(counter))
} }
currentVersion, err := p.GetDBVersion(ctx) currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err) check.NoError(t, err)
@ -131,8 +131,8 @@ func TestProviderRun(t *testing.T) {
break break
} }
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) check.NotNil(t, res)
check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1)) check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
} }
// Once everything is tested the version should match the highest testdata version // Once everything is tested the version should match the highest testdata version
currentVersion, err = p.GetDBVersion(ctx) currentVersion, err = p.GetDBVersion(ctx)
@ -148,8 +148,8 @@ func TestProviderRun(t *testing.T) {
results, err := p.UpTo(ctx, upToVersion) results, err := p.UpTo(ctx, upToVersion)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(results), upToVersion) check.Number(t, len(results), upToVersion)
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
// Fetch the goose version from DB // Fetch the goose version from DB
currentVersion, err := p.GetDBVersion(ctx) currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err) check.NoError(t, err)
@ -237,7 +237,11 @@ func TestProviderRun(t *testing.T) {
res, err := p.ApplyVersion(ctx, s.Version, true) res, err := p.ApplyVersion(ctx, s.Version, true)
check.NoError(t, err) check.NoError(t, err)
// Round-trip the migration result through the database to ensure it's valid. // Round-trip the migration result through the database to ensure it's valid.
assertResult(t, res, s, "up") var empty bool
if s.Version == 6 || s.Version == 7 {
empty = true
}
assertResult(t, res, s, "up", empty)
} }
// Apply all migrations in the down direction. // Apply all migrations in the down direction.
for i := len(sources) - 1; i >= 0; i-- { for i := len(sources) - 1; i >= 0; i-- {
@ -245,7 +249,11 @@ func TestProviderRun(t *testing.T) {
res, err := p.ApplyVersion(ctx, s.Version, false) res, err := p.ApplyVersion(ctx, s.Version, false)
check.NoError(t, err) check.NoError(t, err)
// Round-trip the migration result through the database to ensure it's valid. // Round-trip the migration result through the database to ensure it's valid.
assertResult(t, res, s, "down") var empty bool
if s.Version == 6 || s.Version == 7 {
empty = true
}
assertResult(t, res, s, "down", empty)
} }
// Try apply version 1 multiple times // Try apply version 1 multiple times
_, err := p.ApplyVersion(ctx, 1, true) _, err := p.ApplyVersion(ctx, 1, true)
@ -324,7 +332,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
// Check Results field // Check Results field
check.Number(t, len(expected.Applied), 1) check.Number(t, len(expected.Applied), 1)
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
// Check Failed field // Check Failed field
check.Bool(t, expected.Failed != nil, true) check.Bool(t, expected.Failed != nil, true)
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
@ -368,11 +376,11 @@ func TestConcurrentProvider(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
if len(res) != 1 { if res == nil {
t.Errorf("expected 1 result, got %d", len(res)) t.Errorf("expected non-nil result, got nil")
return return
} }
ch <- res[0].Source.Version ch <- res.Source.Version
}() }()
} }
go func() { go func() {
@ -623,13 +631,13 @@ func TestAllowMissing(t *testing.T) {
// 4 // 4
upResult, err := p.UpByOne(ctx) upResult, err := p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(upResult), 1) check.NotNil(t, upResult)
check.Number(t, upResult[0].Source.Version, 4) check.Number(t, upResult.Source.Version, 4)
// 6 // 6
upResult, err = p.UpByOne(ctx) upResult, err = p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(upResult), 1) check.NotNil(t, upResult)
check.Number(t, upResult[0].Source.Version, 6) check.Number(t, upResult.Source.Version, 6)
count, err := getGooseVersionCount(db, provider.DefaultTablename) count, err := getGooseVersionCount(db, provider.DefaultTablename)
check.NoError(t, err) check.NoError(t, err)
@ -645,21 +653,29 @@ func TestAllowMissing(t *testing.T) {
// So migrating down should be the reverse of the applied order: // So migrating down should be the reverse of the applied order:
// 6,4,5,3,2,1 // 6,4,5,3,2,1
expected := []int64{6, 4, 5, 3, 2, 1} testDownAndVersion := func(wantDBVersion, wantResultVersion int64) {
for i, v := range expected { currentVersion, err := p.GetDBVersion(ctx)
// TODO(mf): this is returning it by the order it was applied.
current, err := p.GetDBVersion(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, current, v) check.Number(t, currentVersion, wantDBVersion)
downResult, err := p.Down(ctx) downRes, err := p.Down(ctx)
if i == len(expected)-1 {
check.HasError(t, provider.ErrVersionNotFound)
} else {
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(downResult), 1) check.NotNil(t, downRes)
check.Number(t, downResult[0].Source.Version, v) check.Number(t, downRes.Source.Version, wantResultVersion)
}
} }
// This behaviour may need to change, see the following issues for more details:
// - https://github.com/pressly/goose/issues/523
// - https://github.com/pressly/goose/issues/402
testDownAndVersion(6, 6)
testDownAndVersion(5, 4) // Ensure the max db version is 5 before down.
testDownAndVersion(5, 5)
testDownAndVersion(3, 3)
testDownAndVersion(2, 2)
testDownAndVersion(1, 1)
_, err = p.Down(ctx)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
}) })
} }
@ -674,7 +690,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
} }
func TestGoOnly(t *testing.T) { func TestGoOnly(t *testing.T) {
// Not parallel because it modifies global state. // Not parallel because each subtest modifies global state.
countUser := func(db *sql.DB) int { countUser := func(db *sql.DB) int {
q := `SELECT count(*)FROM users` q := `SELECT count(*)FROM users`
@ -688,7 +704,7 @@ func TestGoOnly(t *testing.T) {
ctx := context.Background() ctx := context.Background()
register := []*provider.Migration{ register := []*provider.Migration{
{ {
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnContext: newTxFn("DROP TABLE users"), DownFnContext: newTxFn("DROP TABLE users"),
}, },
@ -713,27 +729,23 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1 // Apply migration 1
res, err := p.UpByOne(ctx) res, err := p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
check.Number(t, countUser(db), 0) check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true) check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2 // Apply migration 2
res, err = p.UpByOne(ctx) res, err = p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
check.Number(t, countUser(db), 3) check.Number(t, countUser(db), 3)
// Rollback migration 2 // Rollback migration 2
res, err = p.Down(ctx) res, err = p.Down(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
check.Number(t, countUser(db), 0) check.Number(t, countUser(db), 0)
// Rollback migration 1 // Rollback migration 1
res, err = p.Down(ctx) res, err = p.Down(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
// Check table does not exist // Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false) check.Bool(t, tableExists(t, db, "users"), false)
}) })
@ -741,7 +753,7 @@ func TestGoOnly(t *testing.T) {
ctx := context.Background() ctx := context.Background()
register := []*provider.Migration{ register := []*provider.Migration{
{ {
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnNoTxContext: newDBFn("DROP TABLE users"), DownFnNoTxContext: newDBFn("DROP TABLE users"),
}, },
@ -766,27 +778,23 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1 // Apply migration 1
res, err := p.UpByOne(ctx) res, err := p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
check.Number(t, countUser(db), 0) check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true) check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2 // Apply migration 2
res, err = p.UpByOne(ctx) res, err = p.UpByOne(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
check.Number(t, countUser(db), 3) check.Number(t, countUser(db), 3)
// Rollback migration 2 // Rollback migration 2
res, err = p.Down(ctx) res, err = p.Down(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
check.Number(t, countUser(db), 0) check.Number(t, countUser(db), 0)
// Rollback migration 1 // Rollback migration 1
res, err = p.Down(ctx) res, err = p.Down(ctx)
check.NoError(t, err) check.NoError(t, err)
check.Number(t, len(res), 1) assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
// Check table does not exist // Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false) check.Bool(t, tableExists(t, db, "users"), false)
}) })
@ -880,7 +888,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
) )
g.Go(func() error { g.Go(func() error {
for { for {
results, err := provider1.UpByOne(context.Background()) result, err := provider1.UpByOne(context.Background())
if err != nil { if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) { if errors.Is(err, provider.ErrNoNextVersion) {
return nil return nil
@ -888,17 +896,15 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err return err
} }
check.NoError(t, err) check.NoError(t, err)
if len(results) != 1 { check.NotNil(t, result)
return fmt.Errorf("expected 1 result, got %d", len(results))
}
mu.Lock() mu.Lock()
applied = append(applied, results[0].Source.Version) applied = append(applied, result.Source.Version)
mu.Unlock() mu.Unlock()
} }
}) })
g.Go(func() error { g.Go(func() error {
for { for {
results, err := provider2.UpByOne(context.Background()) result, err := provider2.UpByOne(context.Background())
if err != nil { if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) { if errors.Is(err, provider.ErrNoNextVersion) {
return nil return nil
@ -906,11 +912,9 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err return err
} }
check.NoError(t, err) check.NoError(t, err)
if len(results) != 1 { check.NotNil(t, result)
return fmt.Errorf("expected 1 result, got %d", len(results))
}
mu.Lock() mu.Lock()
applied = append(applied, results[0].Source.Version) applied = append(applied, result.Source.Version)
mu.Unlock() mu.Unlock()
} }
}) })
@ -986,37 +990,33 @@ func TestLockModeAdvisorySession(t *testing.T) {
) )
g.Go(func() error { g.Go(func() error {
for { for {
results, err := provider1.Down(context.Background()) result, err := provider1.Down(context.Background())
if err != nil { if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) { if errors.Is(err, provider.ErrNoNextVersion) {
return nil return nil
} }
return err return err
} }
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NoError(t, err) check.NoError(t, err)
check.NotNil(t, result)
mu.Lock() mu.Lock()
applied = append(applied, results[0].Source.Version) applied = append(applied, result.Source.Version)
mu.Unlock() mu.Unlock()
} }
}) })
g.Go(func() error { g.Go(func() error {
for { for {
results, err := provider2.Down(context.Background()) result, err := provider2.Down(context.Background())
if err != nil { if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) { if errors.Is(err, provider.ErrNoNextVersion) {
return nil return nil
} }
return err return err
} }
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NoError(t, err) check.NoError(t, err)
check.NotNil(t, result)
mu.Lock() mu.Lock()
applied = append(applied, results[0].Source.Version) applied = append(applied, result.Source.Version)
mu.Unlock() mu.Unlock()
} }
}) })
@ -1124,11 +1124,12 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
} }
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) { func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
t.Helper() t.Helper()
check.NotNil(t, got)
check.Equal(t, got.Source, source) check.Equal(t, got.Source, source)
check.Equal(t, got.Direction, direction) check.Equal(t, got.Direction, direction)
check.Equal(t, got.Empty, false) check.Equal(t, got.Empty, isEmpty)
check.Bool(t, got.Error == nil, true) check.Bool(t, got.Error == nil, true)
check.Bool(t, got.Duration > 0, true) check.Bool(t, got.Duration > 0, true)
} }
@ -1136,7 +1137,7 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
t.Helper() t.Helper()
check.Equal(t, got.Type, typ) check.Equal(t, got.Type, typ)
check.Equal(t, got.Fullpath, name) check.Equal(t, got.Path, name)
check.Equal(t, got.Version, version) check.Equal(t, got.Version, version)
switch got.Type { switch got.Type {
case provider.TypeGo: case provider.TypeGo:

View File

@ -1,96 +0,0 @@
package provider
import (
"context"
"errors"
"fmt"
"strings"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
// Short circuit if versioning is disabled and apply all migrations.
return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne)
}
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the out-of-order (allow missing) feature. For users who don't use this feature, we
// could just query the database for the current version and then apply migrations that are
// greater than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
dbMaxVersion := dbMigrations[0].Version
// lookupAppliedInDB is a map of all applied migrations in the database.
lookupAppliedInDB := make(map[int64]bool)
for _, m := range dbMigrations {
lookupAppliedInDB[m.Version] = true
}
missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion)
// feature(mf): It is very possible someone may want to apply ONLY new migrations and skip
// missing migrations entirely. At the moment this is not supported, but leaving this comment
// because that's where that logic will be handled.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]",
len(missingMigrations), msg, strings.Join(collected, ","))
}
var migrationsToApply []*migration
if p.cfg.allowMissing {
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
migrationsToApply = append(migrationsToApply, m)
}
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max).
for _, m := range p.migrations {
if lookupAppliedInDB[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
migrationsToApply = append(migrationsToApply, m)
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Note, we can't use a single transaction for all migrations because some may have to be run in
// their own transaction.
return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne)
}

View File

@ -41,36 +41,23 @@ func (t MigrationType) String() string {
// Source represents a single migration source. // Source represents a single migration source.
// //
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if // The Path field may be empty if the migration was registered manually. This is typically the case
// the migration has a corresponding file on disk. It will be empty if the migration was registered // for Go migrations registered using the [WithGoMigration] option.
// manually.
type Source struct { type Source struct {
// Type is the type of migration.
Type MigrationType Type MigrationType
// Full path to the migration file. Path string
//
// Example: /path/to/migrations/001_create_users_table.sql
Fullpath string
// Version is the version of the migration.
Version int64 Version int64
} }
// MigrationResult is the result of a single migration operation. // MigrationResult is the result of a single migration operation.
//
// Note, the caller is responsible for checking the Error field for any errors that occurred while
// running the migration. If the Error field is not nil, the migration failed.
type MigrationResult struct { type MigrationResult struct {
Source Source Source Source
Duration time.Duration Duration time.Duration
Direction string Direction string
// Empty is true if the file was valid, but no statements to apply. These are still versioned // Empty indicates no action was taken during the migration, but it was still versioned. For
// migrations, but typically have no effect on the database. // SQL, it means no statements; for Go, it's a nil function.
//
// For SQL migrations, this means there was a valid .sql file but contained no statements. For
// Go migrations, this means the function was nil.
Empty bool Empty bool
// Error is only set if the migration failed.
// Error is any error that occurred while running the migration.
Error error Error error
} }
@ -78,22 +65,20 @@ type MigrationResult struct {
type State string type State string
const ( const (
// StatePending represents a migration that is on the filesystem, but not in the database. // StatePending is a migration that exists on the filesystem, but not in the database.
StatePending State = "pending" StatePending State = "pending"
// StateApplied represents a migration that is in BOTH the database and on the filesystem. // StateApplied is a migration that has been applied to the database and exists on the
// filesystem.
StateApplied State = "applied" StateApplied State = "applied"
// StateUntracked represents a migration that is in the database, but not on the filesystem. // TODO(mf): we could also add a third state for untracked migrations. This would be useful for
// StateUntracked State = "untracked" // migrations that were manually applied to the database, but not versioned. Or the Source was
// deleted, but the migration still exists in the database. StateUntracked State = "untracked"
) )
// MigrationStatus represents the status of a single migration. // MigrationStatus represents the status of a single migration.
type MigrationStatus struct { type MigrationStatus struct {
// State is the state of the migration.
State State
// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
// [StateUntracked].
AppliedAt time.Time
// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
Source Source Source Source
State State
AppliedAt time.Time
} }

View File

@ -61,6 +61,25 @@ func TestStore(t *testing.T) {
check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
}) })
}) })
t.Run("ListMigrations", func(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
store, err := sqladapter.NewStore("sqlite3", "foo")
check.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2))
res, err := store.ListMigrations(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(res), 3)
// Check versions are in descending order: [2, 3, 1]
check.Number(t, res[0].Version, 2)
check.Number(t, res[1].Version, 3)
check.Number(t, res[2].Version, 1)
})
} }
// testStore tests various store operations. // testStore tests various store operations.