feat(experimental): Shuffle packages and tidy up (#619)

This commit is contained in:
Michael Fridman 2023-10-25 08:56:17 -04:00 committed by GitHub
parent 3b801a60c7
commit a9da7504fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 518 additions and 378 deletions

View File

@ -8,6 +8,13 @@ import (
"testing"
)
func NotNil(t *testing.T, v any) {
t.Helper()
if v == nil {
t.Fatal("unexpected nil value")
}
}
func NoError(t *testing.T, err error) {
t.Helper()
if err != nil {

View File

@ -13,9 +13,9 @@ import (
func NewSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Fullpath: fullpath,
Version: version,
Type: t,
Path: fullpath,
Version: version,
}
}
@ -133,7 +133,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
var unregistered []string
for _, s := range sources.goSources {
if _, ok := registerd[s.Version]; !ok {
unregistered = append(unregistered, s.Fullpath)
unregistered = append(unregistered, s.Path)
}
}
if len(unregistered) > 0 {
@ -149,7 +149,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
fullpath := r.fullpath
if fullpath == "" {
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath
fullpath = s.Path
}
}
// Ensure there are no duplicate versions.
@ -160,7 +160,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
}
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing.Source.Fullpath,
existing.Source.Path,
fullpath,
)
}

View File

@ -6,6 +6,7 @@ import (
"testing/fstest"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
)
func TestCollectFileSources(t *testing.T) {
@ -120,13 +121,13 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
// 1
check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql")
check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql")
check.Equal(t, sources.sqlSources[0].Version, int64(1))
// 2
check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql")
check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql")
check.Equal(t, sources.sqlSources[1].Version, int64(5))
// 3
check.Equal(t, sources.goSources[0].Fullpath, "4_something.go")
check.Equal(t, sources.goSources[0].Path, "4_something.go")
check.Equal(t, sources.goSources[0].Version, int64(4))
})
t.Run("duplicate_versions", func(t *testing.T) {
@ -286,6 +287,65 @@ func TestMerge(t *testing.T) {
})
}
func TestFindMissingMigrations(t *testing.T) {
t.Parallel()
t.Run("db_has_max_version", func(t *testing.T) {
// Test case: database has migrations 1, 3, 4, 5, 7
// Missing migrations: 2, 6
// Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8
dbMigrations := []*sqladapter.ListMigrationsResult{
{Version: 1},
{Version: 3},
{Version: 4},
{Version: 5},
{Version: 7}, // <-- database max version_id
}
fsMigrations := []*migration{
newMigration(1),
newMigration(2), // missing migration
newMigration(3),
newMigration(4),
newMigration(5),
newMigration(6), // missing migration
newMigration(7), // ----- database max version_id -----
newMigration(8), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 2)
check.Number(t, got[1].versionID, 6)
// Sanity check.
check.Number(t, len(findMissingMigrations(nil, nil)), 0)
check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0)
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
})
t.Run("fs_has_max_version", func(t *testing.T) {
dbMigrations := []*sqladapter.ListMigrationsResult{
{Version: 1},
{Version: 5},
{Version: 2},
}
fsMigrations := []*migration{
newMigration(3), // new migration
newMigration(4), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 3)
check.Number(t, got[1].versionID, 4)
})
}
func newMigration(version int64) *migration {
return &migration{
Source: Source{
Version: version,
},
}
}
func assertMigration(t *testing.T, got *migration, want Source) {
t.Helper()
check.Equal(t, got.Source, want)

View File

@ -22,18 +22,19 @@ var (
// PartialError is returned when a migration fails, but some migrations already got applied.
type PartialError struct {
// Applied are migrations that were applied successfully before the error occurred.
// Applied are migrations that were applied successfully before the error occurred. May be
// empty.
Applied []*MigrationResult
// Failed contains the result of the migration that failed.
// Failed contains the result of the migration that failed. Cannot be nil.
Failed *MigrationResult
// Err is the error that occurred while running the migration.
// Err is the error that occurred while running the migration and caused the failure.
Err error
}
func (e *PartialError) Error() string {
filename := "(file unknown)"
if e.Failed != nil && e.Failed.Source.Fullpath != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath))
if e.Failed != nil && e.Failed.Source.Path != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
}
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
}

View File

@ -29,7 +29,7 @@ func (m *migration) useTx(direction bool) bool {
case TypeSQL:
return m.SQL.UseTx
case TypeGo:
if m.Go == nil {
if m.Go == nil || m.Go.isEmpty(direction) {
return false
}
if direction {
@ -41,8 +41,18 @@ func (m *migration) useTx(direction bool) bool {
return false
}
func (m *migration) isEmpty(direction bool) bool {
switch m.Source.Type {
case TypeSQL:
return m.SQL == nil || m.SQL.IsEmpty(direction)
case TypeGo:
return m.Go == nil || m.Go.isEmpty(direction)
}
return true
}
func (m *migration) filename() string {
return filepath.Base(m.Source.Fullpath)
return filepath.Base(m.Source.Path)
}
// run runs the migration inside of a transaction.
@ -57,7 +67,7 @@ func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
return m.Go.run(ctx, tx, direction)
}
// This should never happen.
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
// runNoTx runs the migration without a transaction.
@ -72,7 +82,7 @@ func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) err
return m.Go.runNoTx(ctx, db, direction)
}
// This should never happen.
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
// runConn runs the migration without a transaction using the provided connection.
@ -87,7 +97,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool)
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
}
// This should never happen.
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
type goMigration struct {
@ -95,6 +105,16 @@ type goMigration struct {
up, down *GoMigration
}
func (g *goMigration) isEmpty(direction bool) bool {
if g.up == nil && g.down == nil {
panic("go migration has no up or down")
}
if direction {
return g.up == nil
}
return g.down == nil
}
func newGoMigration(fullpath string, up, down *GoMigration) *goMigration {
return &goMigration{
fullpath: fullpath,

View File

@ -8,24 +8,48 @@ import (
)
type Migration struct {
Version int64
Source string // path to .sql script or go file
Registered bool
UseTx bool
UpFnContext func(context.Context, *sql.Tx) error
DownFnContext func(context.Context, *sql.Tx) error
UpFnNoTxContext func(context.Context, *sql.DB) error
DownFnNoTxContext func(context.Context, *sql.DB) error
Version int64
Source string // path to .sql script or go file
Registered bool
UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
}
var registeredGoMigrations = make(map[int64]*Migration)
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
// the migrations are nil or if a migration with the same version has already been registered.
//
// Not safe for concurrent use.
func SetGlobalGoMigrations(migrations []*Migration) error {
for _, m := range migrations {
if m == nil {
return errors.New("cannot register nil go migration")
}
if m.Version < 1 {
return errors.New("migration versions must be greater than zero")
}
if !m.Registered {
return errors.New("migration must be registered")
}
if m.Source != "" {
// If the source is set, expect it to be a file path with a numeric component that
// matches the version.
version, err := NumericComponent(m.Source)
if err != nil {
return err
}
if version != m.Version {
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
}
}
// It's valid for all of these to be nil.
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
@ -34,6 +58,9 @@ func SetGlobalGoMigrations(migrations []*Migration) error {
return nil
}
// ResetGlobalGoMigrations resets the global go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalGoMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
}

View File

@ -19,8 +19,9 @@ import (
// github.com/lib/pq or github.com/jackc/pgx.
//
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub.
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
// migrations using [fs.Sub].
//
// See [ProviderOption] for more information on configuring the provider.
//
@ -39,6 +40,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
}
cfg := config{
registered: make(map[int64]*goMigration),
excludes: make(map[string]bool),
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
@ -133,10 +135,12 @@ type Provider struct {
// database.
mu sync.Mutex
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
// migrations are ordered by version in ascending order.
migrations []*migration
}
@ -149,8 +153,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
// GetDBVersion returns the max version from the database, regardless of the applied order. For
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
// applied, it returns 0.
//
// TODO(mf): this is not true?
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return p.getDBVersion(ctx)
}
@ -175,25 +177,28 @@ func (p *Provider) Close() error {
return p.db.Close()
}
// ApplyVersion applies exactly one migration at the specified version. If there is no source for
// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been
// applied already, this method returns [ErrAlreadyApplied].
// ApplyVersion applies exactly one migration by version. If there is no source for the specified
// version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
// this method returns [ErrAlreadyApplied].
//
// When direction is true, the up migration is executed, and when direction is false, the down
// migration is executed.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.apply(ctx, version, direction)
}
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
// empty list and nil error.
// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
// returns empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
return p.up(ctx, false, math.MaxInt64)
}
// UpByOne applies the next available migration. If there are no migrations to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) {
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
@ -201,21 +206,28 @@ func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) {
if len(res) == 0 {
return nil, ErrNoNextVersion
}
return res, nil
// This should never happen. We should always have exactly one result and test for this.
if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
}
return res[0], nil
}
// UpTo applies all available migrations up to and including the specified version. If there are no
// migrations to apply, this method returns empty list and nil error.
// UpTo applies all available migrations up to, and including, the specified version. If there are
// no migrations to apply, this method returns empty list and nil error.
//
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9 and 10 will be applied.
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.up(ctx, false, version)
}
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
// method returns [ErrNoNextVersion].
func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) {
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
res, err := p.down(ctx, true, 0)
if err != nil {
return nil, err
@ -223,16 +235,19 @@ func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) {
if len(res) == 0 {
return nil, ErrNoNextVersion
}
return res, nil
if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
}
return res[0], nil
}
// DownTo rolls back all migrations down to but not including the specified version.
// DownTo rolls back all migrations down to, but not including, the specified version.
//
// For instance, if the current database version is 11, and the requested version is 9, only
// migrations 11 and 10 will be rolled back.
// For instance, if the current database version is 11,10,9... and the requested version is 9, only
// migrations 11, 10 will be rolled back.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 0 {
return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version)
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
}
return p.down(ctx, false, version)
}

View File

@ -10,6 +10,9 @@ import (
)
const (
// DefaultTablename is the default name of the database table used to track history of applied
// migrations. It can be overridden using the [WithTableName] option when creating a new
// provider.
DefaultTablename = "goose_db_version"
)
@ -85,7 +88,8 @@ type GoMigration struct {
// WithGoMigration registers a Go migration with the given version.
//
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set.
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
// set.
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
return configFunc(func(c *config) error {
if version < 1 {
@ -122,9 +126,10 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
// WithAllowMissing allows the provider to apply missing (out-of-order) migrations.
//
// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is
// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied
// migrations will be: 1,6,2,3,5.
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
// by new migrations.
func WithAllowMissing(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.allowMissing = b
@ -132,9 +137,9 @@ func WithAllowMissing(b bool) ProviderOption {
})
}
// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations
// without tracking the versions in the database schema table. Useful for tests, seeding a database
// or running ad-hoc queries.
// WithNoVersioning disables versioning. Disabling versioning allows applying migrations without
// tracking the versions in the database schema table. Useful for tests, seeding a database or
// running ad-hoc queries.
func WithNoVersioning(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.noVersioning = b

View File

@ -41,7 +41,7 @@ func TestProvider(t *testing.T) {
// Not parallel because it modifies global state.
register := []*provider.Migration{
{
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: nil,
DownFnContext: nil,
},
@ -69,32 +69,46 @@ func TestProvider(t *testing.T) {
t.Run("duplicate_up", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
},
})
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext")
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
})
t.Run("duplicate_down", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
Version: 1, Source: "00001_users_table.go", Registered: true,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
},
})
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext")
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
})
t.Run("not_registered", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 1, Source: "00001_users_table.go",
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration must be registered")
})
t.Run("zero_not_allowed", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
{
Version: 0,
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration versions must be greater than zero")
})
}

View File

@ -15,9 +15,153 @@ import (
"go.uber.org/multierr"
)
var (
errMissingZeroVersion = errors.New("missing zero version migration")
)
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
var apply []*migration
if p.cfg.noVersioning {
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the allow missing (out-of-order) feature. For users that don't use this feature, we
// could just query the database for the current max version and then apply migrations greater
// than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
apply, err = p.resolveUpMigrations(dbMigrations, version)
if err != nil {
return nil, err
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Careful, we can't use a single transaction for all migrations because some may have to be run
// in their own transaction.
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
}
func (p *Provider) resolveUpMigrations(
dbVersions []*sqladapter.ListMigrationsResult,
version int64,
) ([]*migration, error) {
var apply []*migration
var dbMaxVersion int64
// dbAppliedVersions is a map of all applied migrations in the database.
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
for _, m := range dbVersions {
dbAppliedVersions[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
missingMigrations := findMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3.
// Not sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max). Skip any migrations that have already been applied.
for _, m := range p.migrations {
if dbAppliedVersions[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
apply = append(apply, m)
}
}
return apply, nil
}
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
downMigrations := p.migrations
if downByOne {
last := p.migrations[len(p.migrations)-1]
downMigrations = []*migration{last}
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
// runMigrations runs migrations sequentially in the given direction.
//
// If the migrations slice is empty, this function returns nil with no error.
// If the migrations list is empty, return nil without error.
func (p *Provider) runMigrations(
ctx context.Context,
conn *sql.Conn,
@ -28,28 +172,20 @@ func (p *Provider) runMigrations(
if len(migrations) == 0 {
return nil, nil
}
var apply []*migration
apply := migrations
if byOne {
apply = []*migration{migrations[0]}
} else {
apply = migrations
apply = migrations[:1]
}
// Lazily parse SQL migrations (if any) in both directions. We do this before running any
// migrations so that we can fail fast if there are any errors and avoid leaving the database in
// a partially migrated state.
if err := parseSQL(p.fsys, false, apply); err != nil {
return nil, err
}
// TODO(mf): If we decide to add support for advisory locks at the transaction level, this may
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
// to run in a transaction.
//
//
//
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
// this will deadlock because the Go migration will try to acquire a connection from the pool,
@ -57,31 +193,32 @@ func (p *Provider) runMigrations(
//
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
// an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 {
// for _, m := range apply {
// if m.IsGo() && !m.Go.UseTx {
// return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1")
// }
// }
// }
// an edge case.
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
for _, m := range apply {
switch m.Source.Type {
case TypeGo:
if m.Go != nil && m.useTx(direction.ToBool()) {
return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
}
}
}
}
// Run migrations individually, opening a new transaction for each migration if the migration is
// safe to run in a transaction.
// Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the
// impression that N migrations were applied when in fact some were not 2. Avoid the caller
// having to check for nil results
// Avoid allocating a slice because we may have a partial migration error.
// 1. Avoid giving the impression that N migrations were applied when in fact some were not
// 2. Avoid the caller having to check for nil results
var results []*MigrationResult
for _, m := range apply {
current := &MigrationResult{
Source: m.Source,
Direction: strings.ToLower(direction.String()),
// TODO(mf): empty set here
Direction: direction.String(),
Empty: m.isEmpty(direction.ToBool()),
}
start := time.Now()
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
// TODO(mf): we should also return the pending migrations here.
// TODO(mf): we should also return the pending migrations here, the remaining items in
// the apply slice.
current.Error = err
current.Duration = time.Since(start)
return nil, &PartialError{
@ -90,16 +227,12 @@ func (p *Provider) runMigrations(
Err: err,
}
}
current.Duration = time.Since(start)
results = append(results, current)
}
return results, nil
}
// runIndividually runs an individual migration, opening a new transaction if the migration is safe
// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the
// supplied connection.
func (p *Provider) runIndividually(
ctx context.Context,
conn *sql.Conn,
@ -182,8 +315,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
cleanup = func() error {
p.mu.Unlock()
// Use a detached context to unlock the session. This is because the context passed to
// SessionLock may have been canceled, and we don't want to cancel the unlock.
// TODO(mf): use [context.WithoutCancel] added in go1.21
// SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
// use [context.WithoutCancel] added in go1.21
detachedCtx := context.Background()
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
}
@ -206,7 +339,7 @@ func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
for _, m := range migrations {
// If the migration is a SQL migration, and it has not been parsed, parse it.
if m.Source.Type == TypeSQL && m.SQL == nil {
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug)
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
if err != nil {
return err
}
@ -248,11 +381,14 @@ type missingMigration struct {
func findMissingMigrations(
dbMigrations []*sqladapter.ListMigrationsResult,
fsMigrations []*migration,
dbMaxVersion int64,
) []missingMigration {
existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations {
existing[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
var missing []missingMigration
for _, m := range fsMigrations {
@ -282,10 +418,6 @@ func (p *Provider) getMigration(version int64) (*migration, error) {
}
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
m, err := p.getMigration(version)
if err != nil {
return nil, err
@ -371,5 +503,8 @@ func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
if len(res) == 0 {
return 0, nil
}
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
return res[0].Version, nil
}

View File

@ -1,53 +0,0 @@
package provider
import (
"context"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
var downMigrations []*migration
if downByOne {
downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1])
} else {
downMigrations = p.migrations
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}

View File

@ -53,13 +53,13 @@ func TestProviderRun(t *testing.T) {
p, _ := newProviderWithDB(t)
_, err := p.UpTo(context.Background(), 0)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero")
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
_, err = p.DownTo(context.Background(), -1)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1")
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
_, err = p.ApplyVersion(context.Background(), 0, true)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero")
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
})
t.Run("up_and_down_all", func(t *testing.T) {
ctx := context.Background()
@ -77,24 +77,24 @@ func TestProviderRun(t *testing.T) {
res, err := p.Up(ctx)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up")
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up")
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up")
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up")
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up")
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up")
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
// Test Down
res, err = p.DownTo(ctx, 0)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down")
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down")
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down")
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down")
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down")
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down")
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down")
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
})
t.Run("up_and_down_by_one", func(t *testing.T) {
ctx := context.Background()
@ -112,8 +112,8 @@ func TestProviderRun(t *testing.T) {
break
}
check.NoError(t, err)
check.Number(t, len(res), 1)
check.Number(t, res[0].Source.Version, int64(counter))
check.NotNil(t, res)
check.Number(t, res.Source.Version, int64(counter))
}
currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err)
@ -131,8 +131,8 @@ func TestProviderRun(t *testing.T) {
break
}
check.NoError(t, err)
check.Number(t, len(res), 1)
check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1))
check.NotNil(t, res)
check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
}
// Once everything is tested the version should match the highest testdata version
currentVersion, err = p.GetDBVersion(ctx)
@ -148,8 +148,8 @@ func TestProviderRun(t *testing.T) {
results, err := p.UpTo(ctx, upToVersion)
check.NoError(t, err)
check.Number(t, len(results), upToVersion)
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up")
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
// Fetch the goose version from DB
currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err)
@ -237,7 +237,11 @@ func TestProviderRun(t *testing.T) {
res, err := p.ApplyVersion(ctx, s.Version, true)
check.NoError(t, err)
// Round-trip the migration result through the database to ensure it's valid.
assertResult(t, res, s, "up")
var empty bool
if s.Version == 6 || s.Version == 7 {
empty = true
}
assertResult(t, res, s, "up", empty)
}
// Apply all migrations in the down direction.
for i := len(sources) - 1; i >= 0; i-- {
@ -245,7 +249,11 @@ func TestProviderRun(t *testing.T) {
res, err := p.ApplyVersion(ctx, s.Version, false)
check.NoError(t, err)
// Round-trip the migration result through the database to ensure it's valid.
assertResult(t, res, s, "down")
var empty bool
if s.Version == 6 || s.Version == 7 {
empty = true
}
assertResult(t, res, s, "down", empty)
}
// Try apply version 1 multiple times
_, err := p.ApplyVersion(ctx, 1, true)
@ -324,7 +332,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
// Check Results field
check.Number(t, len(expected.Applied), 1)
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
// Check Failed field
check.Bool(t, expected.Failed != nil, true)
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
@ -368,11 +376,11 @@ func TestConcurrentProvider(t *testing.T) {
t.Error(err)
return
}
if len(res) != 1 {
t.Errorf("expected 1 result, got %d", len(res))
if res == nil {
t.Errorf("expected non-nil result, got nil")
return
}
ch <- res[0].Source.Version
ch <- res.Source.Version
}()
}
go func() {
@ -623,13 +631,13 @@ func TestAllowMissing(t *testing.T) {
// 4
upResult, err := p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(upResult), 1)
check.Number(t, upResult[0].Source.Version, 4)
check.NotNil(t, upResult)
check.Number(t, upResult.Source.Version, 4)
// 6
upResult, err = p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(upResult), 1)
check.Number(t, upResult[0].Source.Version, 6)
check.NotNil(t, upResult)
check.Number(t, upResult.Source.Version, 6)
count, err := getGooseVersionCount(db, provider.DefaultTablename)
check.NoError(t, err)
@ -645,21 +653,29 @@ func TestAllowMissing(t *testing.T) {
// So migrating down should be the reverse of the applied order:
// 6,4,5,3,2,1
expected := []int64{6, 4, 5, 3, 2, 1}
for i, v := range expected {
// TODO(mf): this is returning it by the order it was applied.
current, err := p.GetDBVersion(ctx)
testDownAndVersion := func(wantDBVersion, wantResultVersion int64) {
currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err)
check.Number(t, current, v)
downResult, err := p.Down(ctx)
if i == len(expected)-1 {
check.HasError(t, provider.ErrVersionNotFound)
} else {
check.NoError(t, err)
check.Number(t, len(downResult), 1)
check.Number(t, downResult[0].Source.Version, v)
}
check.Number(t, currentVersion, wantDBVersion)
downRes, err := p.Down(ctx)
check.NoError(t, err)
check.NotNil(t, downRes)
check.Number(t, downRes.Source.Version, wantResultVersion)
}
// This behaviour may need to change, see the following issues for more details:
// - https://github.com/pressly/goose/issues/523
// - https://github.com/pressly/goose/issues/402
testDownAndVersion(6, 6)
testDownAndVersion(5, 4) // Ensure the max db version is 5 before down.
testDownAndVersion(5, 5)
testDownAndVersion(3, 3)
testDownAndVersion(2, 2)
testDownAndVersion(1, 1)
_, err = p.Down(ctx)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
})
}
@ -674,7 +690,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
}
func TestGoOnly(t *testing.T) {
// Not parallel because it modifies global state.
// Not parallel because each subtest modifies global state.
countUser := func(db *sql.DB) int {
q := `SELECT count(*)FROM users`
@ -688,7 +704,7 @@ func TestGoOnly(t *testing.T) {
ctx := context.Background()
register := []*provider.Migration{
{
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnContext: newTxFn("DROP TABLE users"),
},
@ -713,27 +729,23 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
@ -741,7 +753,7 @@ func TestGoOnly(t *testing.T) {
ctx := context.Background()
register := []*provider.Migration{
{
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnNoTxContext: newDBFn("DROP TABLE users"),
},
@ -766,27 +778,23 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
check.Number(t, len(res), 1)
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
@ -880,7 +888,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
)
g.Go(func() error {
for {
results, err := provider1.UpByOne(context.Background())
result, err := provider1.UpByOne(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
return nil
@ -888,17 +896,15 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NotNil(t, result)
mu.Lock()
applied = append(applied, results[0].Source.Version)
applied = append(applied, result.Source.Version)
mu.Unlock()
}
})
g.Go(func() error {
for {
results, err := provider2.UpByOne(context.Background())
result, err := provider2.UpByOne(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
return nil
@ -906,11 +912,9 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NotNil(t, result)
mu.Lock()
applied = append(applied, results[0].Source.Version)
applied = append(applied, result.Source.Version)
mu.Unlock()
}
})
@ -986,37 +990,33 @@ func TestLockModeAdvisorySession(t *testing.T) {
)
g.Go(func() error {
for {
results, err := provider1.Down(context.Background())
result, err := provider1.Down(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
return nil
}
return err
}
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NoError(t, err)
check.NotNil(t, result)
mu.Lock()
applied = append(applied, results[0].Source.Version)
applied = append(applied, result.Source.Version)
mu.Unlock()
}
})
g.Go(func() error {
for {
results, err := provider2.Down(context.Background())
result, err := provider2.Down(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
return nil
}
return err
}
if len(results) != 1 {
return fmt.Errorf("expected 1 result, got %d", len(results))
}
check.NoError(t, err)
check.NotNil(t, result)
mu.Lock()
applied = append(applied, results[0].Source.Version)
applied = append(applied, result.Source.Version)
mu.Unlock()
}
})
@ -1124,11 +1124,12 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
}
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) {
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
t.Helper()
check.NotNil(t, got)
check.Equal(t, got.Source, source)
check.Equal(t, got.Direction, direction)
check.Equal(t, got.Empty, false)
check.Equal(t, got.Empty, isEmpty)
check.Bool(t, got.Error == nil, true)
check.Bool(t, got.Duration > 0, true)
}
@ -1136,7 +1137,7 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
t.Helper()
check.Equal(t, got.Type, typ)
check.Equal(t, got.Fullpath, name)
check.Equal(t, got.Path, name)
check.Equal(t, got.Version, version)
switch got.Type {
case provider.TypeGo:

View File

@ -1,96 +0,0 @@
package provider
import (
"context"
"errors"
"fmt"
"strings"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
// Short circuit if versioning is disabled and apply all migrations.
return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne)
}
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the out-of-order (allow missing) feature. For users who don't use this feature, we
// could just query the database for the current version and then apply migrations that are
// greater than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
dbMaxVersion := dbMigrations[0].Version
// lookupAppliedInDB is a map of all applied migrations in the database.
lookupAppliedInDB := make(map[int64]bool)
for _, m := range dbMigrations {
lookupAppliedInDB[m.Version] = true
}
missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion)
// feature(mf): It is very possible someone may want to apply ONLY new migrations and skip
// missing migrations entirely. At the moment this is not supported, but leaving this comment
// because that's where that logic will be handled.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]",
len(missingMigrations), msg, strings.Join(collected, ","))
}
var migrationsToApply []*migration
if p.cfg.allowMissing {
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
migrationsToApply = append(migrationsToApply, m)
}
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max).
for _, m := range p.migrations {
if lookupAppliedInDB[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
migrationsToApply = append(migrationsToApply, m)
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Note, we can't use a single transaction for all migrations because some may have to be run in
// their own transaction.
return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne)
}

View File

@ -41,36 +41,23 @@ func (t MigrationType) String() string {
// Source represents a single migration source.
//
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
// the migration has a corresponding file on disk. It will be empty if the migration was registered
// manually.
// The Path field may be empty if the migration was registered manually. This is typically the case
// for Go migrations registered using the [WithGoMigration] option.
type Source struct {
// Type is the type of migration.
Type MigrationType
// Full path to the migration file.
//
// Example: /path/to/migrations/001_create_users_table.sql
Fullpath string
// Version is the version of the migration.
Type MigrationType
Path string
Version int64
}
// MigrationResult is the result of a single migration operation.
//
// Note, the caller is responsible for checking the Error field for any errors that occurred while
// running the migration. If the Error field is not nil, the migration failed.
type MigrationResult struct {
Source Source
Duration time.Duration
Direction string
// Empty is true if the file was valid, but no statements to apply. These are still versioned
// migrations, but typically have no effect on the database.
//
// For SQL migrations, this means there was a valid .sql file but contained no statements. For
// Go migrations, this means the function was nil.
// Empty indicates no action was taken during the migration, but it was still versioned. For
// SQL, it means no statements; for Go, it's a nil function.
Empty bool
// Error is any error that occurred while running the migration.
// Error is only set if the migration failed.
Error error
}
@ -78,22 +65,20 @@ type MigrationResult struct {
type State string
const (
// StatePending represents a migration that is on the filesystem, but not in the database.
// StatePending is a migration that exists on the filesystem, but not in the database.
StatePending State = "pending"
// StateApplied represents a migration that is in BOTH the database and on the filesystem.
// StateApplied is a migration that has been applied to the database and exists on the
// filesystem.
StateApplied State = "applied"
// StateUntracked represents a migration that is in the database, but not on the filesystem.
// StateUntracked State = "untracked"
// TODO(mf): we could also add a third state for untracked migrations. This would be useful for
// migrations that were manually applied to the database, but not versioned. Or the Source was
// deleted, but the migration still exists in the database. StateUntracked State = "untracked"
)
// MigrationStatus represents the status of a single migration.
type MigrationStatus struct {
// State is the state of the migration.
State State
// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
// [StateUntracked].
Source Source
State State
AppliedAt time.Time
// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
Source Source
}

View File

@ -61,6 +61,25 @@ func TestStore(t *testing.T) {
check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
})
})
t.Run("ListMigrations", func(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
store, err := sqladapter.NewStore("sqlite3", "foo")
check.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3))
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2))
res, err := store.ListMigrations(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(res), 3)
// Check versions are in descending order: [2, 3, 1]
check.Number(t, res[0].Version, 2)
check.Number(t, res[1].Version, 3)
check.Number(t, res[2].Version, 1)
})
}
// testStore tests various store operations.