mirror of
https://github.com/pressly/goose.git
synced 2025-05-28 02:03:08 +00:00
feat(experimental): Shuffle packages and tidy up (#619)
This commit is contained in:
parent
3b801a60c7
commit
a9da7504fa
@ -8,6 +8,13 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func NotNil(t *testing.T, v any) {
|
||||
t.Helper()
|
||||
if v == nil {
|
||||
t.Fatal("unexpected nil value")
|
||||
}
|
||||
}
|
||||
|
||||
func NoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
|
@ -13,9 +13,9 @@ import (
|
||||
|
||||
func NewSource(t MigrationType, fullpath string, version int64) Source {
|
||||
return Source{
|
||||
Type: t,
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
Type: t,
|
||||
Path: fullpath,
|
||||
Version: version,
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,7 +133,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
|
||||
var unregistered []string
|
||||
for _, s := range sources.goSources {
|
||||
if _, ok := registerd[s.Version]; !ok {
|
||||
unregistered = append(unregistered, s.Fullpath)
|
||||
unregistered = append(unregistered, s.Path)
|
||||
}
|
||||
}
|
||||
if len(unregistered) > 0 {
|
||||
@ -149,7 +149,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
|
||||
fullpath := r.fullpath
|
||||
if fullpath == "" {
|
||||
if s := sources.lookup(TypeGo, version); s != nil {
|
||||
fullpath = s.Fullpath
|
||||
fullpath = s.Path
|
||||
}
|
||||
}
|
||||
// Ensure there are no duplicate versions.
|
||||
@ -160,7 +160,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
|
||||
}
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
version,
|
||||
existing.Source.Fullpath,
|
||||
existing.Source.Path,
|
||||
fullpath,
|
||||
)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/sqladapter"
|
||||
)
|
||||
|
||||
func TestCollectFileSources(t *testing.T) {
|
||||
@ -120,13 +121,13 @@ func TestCollectFileSources(t *testing.T) {
|
||||
check.Number(t, len(sources.sqlSources), 2)
|
||||
check.Number(t, len(sources.goSources), 1)
|
||||
// 1
|
||||
check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql")
|
||||
check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql")
|
||||
check.Equal(t, sources.sqlSources[0].Version, int64(1))
|
||||
// 2
|
||||
check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql")
|
||||
check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql")
|
||||
check.Equal(t, sources.sqlSources[1].Version, int64(5))
|
||||
// 3
|
||||
check.Equal(t, sources.goSources[0].Fullpath, "4_something.go")
|
||||
check.Equal(t, sources.goSources[0].Path, "4_something.go")
|
||||
check.Equal(t, sources.goSources[0].Version, int64(4))
|
||||
})
|
||||
t.Run("duplicate_versions", func(t *testing.T) {
|
||||
@ -286,6 +287,65 @@ func TestMerge(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindMissingMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("db_has_max_version", func(t *testing.T) {
|
||||
// Test case: database has migrations 1, 3, 4, 5, 7
|
||||
// Missing migrations: 2, 6
|
||||
// Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8
|
||||
dbMigrations := []*sqladapter.ListMigrationsResult{
|
||||
{Version: 1},
|
||||
{Version: 3},
|
||||
{Version: 4},
|
||||
{Version: 5},
|
||||
{Version: 7}, // <-- database max version_id
|
||||
}
|
||||
fsMigrations := []*migration{
|
||||
newMigration(1),
|
||||
newMigration(2), // missing migration
|
||||
newMigration(3),
|
||||
newMigration(4),
|
||||
newMigration(5),
|
||||
newMigration(6), // missing migration
|
||||
newMigration(7), // ----- database max version_id -----
|
||||
newMigration(8), // new migration
|
||||
}
|
||||
got := findMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
check.Number(t, got[0].versionID, 2)
|
||||
check.Number(t, got[1].versionID, 6)
|
||||
|
||||
// Sanity check.
|
||||
check.Number(t, len(findMissingMigrations(nil, nil)), 0)
|
||||
check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0)
|
||||
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
|
||||
})
|
||||
t.Run("fs_has_max_version", func(t *testing.T) {
|
||||
dbMigrations := []*sqladapter.ListMigrationsResult{
|
||||
{Version: 1},
|
||||
{Version: 5},
|
||||
{Version: 2},
|
||||
}
|
||||
fsMigrations := []*migration{
|
||||
newMigration(3), // new migration
|
||||
newMigration(4), // new migration
|
||||
}
|
||||
got := findMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
check.Number(t, got[0].versionID, 3)
|
||||
check.Number(t, got[1].versionID, 4)
|
||||
})
|
||||
}
|
||||
|
||||
func newMigration(version int64) *migration {
|
||||
return &migration{
|
||||
Source: Source{
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func assertMigration(t *testing.T, got *migration, want Source) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.Source, want)
|
||||
|
@ -22,18 +22,19 @@ var (
|
||||
|
||||
// PartialError is returned when a migration fails, but some migrations already got applied.
|
||||
type PartialError struct {
|
||||
// Applied are migrations that were applied successfully before the error occurred.
|
||||
// Applied are migrations that were applied successfully before the error occurred. May be
|
||||
// empty.
|
||||
Applied []*MigrationResult
|
||||
// Failed contains the result of the migration that failed.
|
||||
// Failed contains the result of the migration that failed. Cannot be nil.
|
||||
Failed *MigrationResult
|
||||
// Err is the error that occurred while running the migration.
|
||||
// Err is the error that occurred while running the migration and caused the failure.
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *PartialError) Error() string {
|
||||
filename := "(file unknown)"
|
||||
if e.Failed != nil && e.Failed.Source.Fullpath != "" {
|
||||
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath))
|
||||
if e.Failed != nil && e.Failed.Source.Path != "" {
|
||||
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
|
||||
}
|
||||
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func (m *migration) useTx(direction bool) bool {
|
||||
case TypeSQL:
|
||||
return m.SQL.UseTx
|
||||
case TypeGo:
|
||||
if m.Go == nil {
|
||||
if m.Go == nil || m.Go.isEmpty(direction) {
|
||||
return false
|
||||
}
|
||||
if direction {
|
||||
@ -41,8 +41,18 @@ func (m *migration) useTx(direction bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *migration) isEmpty(direction bool) bool {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
return m.SQL == nil || m.SQL.IsEmpty(direction)
|
||||
case TypeGo:
|
||||
return m.Go == nil || m.Go.isEmpty(direction)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *migration) filename() string {
|
||||
return filepath.Base(m.Source.Fullpath)
|
||||
return filepath.Base(m.Source.Path)
|
||||
}
|
||||
|
||||
// run runs the migration inside of a transaction.
|
||||
@ -57,7 +67,7 @@ func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
return m.Go.run(ctx, tx, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
// runNoTx runs the migration without a transaction.
|
||||
@ -72,7 +82,7 @@ func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) err
|
||||
return m.Go.runNoTx(ctx, db, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
// runConn runs the migration without a transaction using the provided connection.
|
||||
@ -87,7 +97,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool)
|
||||
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
type goMigration struct {
|
||||
@ -95,6 +105,16 @@ type goMigration struct {
|
||||
up, down *GoMigration
|
||||
}
|
||||
|
||||
func (g *goMigration) isEmpty(direction bool) bool {
|
||||
if g.up == nil && g.down == nil {
|
||||
panic("go migration has no up or down")
|
||||
}
|
||||
if direction {
|
||||
return g.up == nil
|
||||
}
|
||||
return g.down == nil
|
||||
}
|
||||
|
||||
func newGoMigration(fullpath string, up, down *GoMigration) *goMigration {
|
||||
return &goMigration{
|
||||
fullpath: fullpath,
|
||||
|
@ -8,24 +8,48 @@ import (
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
Version int64
|
||||
Source string // path to .sql script or go file
|
||||
Registered bool
|
||||
UseTx bool
|
||||
UpFnContext func(context.Context, *sql.Tx) error
|
||||
DownFnContext func(context.Context, *sql.Tx) error
|
||||
|
||||
UpFnNoTxContext func(context.Context, *sql.DB) error
|
||||
DownFnNoTxContext func(context.Context, *sql.DB) error
|
||||
Version int64
|
||||
Source string // path to .sql script or go file
|
||||
Registered bool
|
||||
UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
|
||||
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
var registeredGoMigrations = make(map[int64]*Migration)
|
||||
|
||||
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
|
||||
// the migrations are nil or if a migration with the same version has already been registered.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func SetGlobalGoMigrations(migrations []*Migration) error {
|
||||
for _, m := range migrations {
|
||||
if m == nil {
|
||||
return errors.New("cannot register nil go migration")
|
||||
}
|
||||
if m.Version < 1 {
|
||||
return errors.New("migration versions must be greater than zero")
|
||||
}
|
||||
if !m.Registered {
|
||||
return errors.New("migration must be registered")
|
||||
}
|
||||
if m.Source != "" {
|
||||
// If the source is set, expect it to be a file path with a numeric component that
|
||||
// matches the version.
|
||||
version, err := NumericComponent(m.Source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if version != m.Version {
|
||||
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
|
||||
}
|
||||
}
|
||||
// It's valid for all of these to be nil.
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
}
|
||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
}
|
||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||
}
|
||||
@ -34,6 +58,9 @@ func SetGlobalGoMigrations(migrations []*Migration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetGlobalGoMigrations resets the global go migrations registry.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func ResetGlobalGoMigrations() {
|
||||
registeredGoMigrations = make(map[int64]*Migration)
|
||||
}
|
||||
|
@ -19,8 +19,9 @@ import (
|
||||
// github.com/lib/pq or github.com/jackc/pgx.
|
||||
//
|
||||
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
|
||||
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
|
||||
// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub.
|
||||
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
|
||||
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
|
||||
// migrations using [fs.Sub].
|
||||
//
|
||||
// See [ProviderOption] for more information on configuring the provider.
|
||||
//
|
||||
@ -39,6 +40,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
|
||||
}
|
||||
cfg := config{
|
||||
registered: make(map[int64]*goMigration),
|
||||
excludes: make(map[string]bool),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
@ -133,10 +135,12 @@ type Provider struct {
|
||||
// database.
|
||||
mu sync.Mutex
|
||||
|
||||
db *sql.DB
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
store sqladapter.Store
|
||||
db *sql.DB
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
store sqladapter.Store
|
||||
|
||||
// migrations are ordered by version in ascending order.
|
||||
migrations []*migration
|
||||
}
|
||||
|
||||
@ -149,8 +153,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
||||
// GetDBVersion returns the max version from the database, regardless of the applied order. For
|
||||
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
|
||||
// applied, it returns 0.
|
||||
//
|
||||
// TODO(mf): this is not true?
|
||||
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
||||
return p.getDBVersion(ctx)
|
||||
}
|
||||
@ -175,25 +177,28 @@ func (p *Provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// ApplyVersion applies exactly one migration at the specified version. If there is no source for
|
||||
// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been
|
||||
// applied already, this method returns [ErrAlreadyApplied].
|
||||
// ApplyVersion applies exactly one migration by version. If there is no source for the specified
|
||||
// version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
|
||||
// this method returns [ErrAlreadyApplied].
|
||||
//
|
||||
// When direction is true, the up migration is executed, and when direction is false, the down
|
||||
// migration is executed.
|
||||
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
||||
if version < 1 {
|
||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
||||
}
|
||||
return p.apply(ctx, version, direction)
|
||||
}
|
||||
|
||||
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
|
||||
// empty list and nil error.
|
||||
// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
|
||||
// returns empty list and nil error.
|
||||
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
||||
return p.up(ctx, false, math.MaxInt64)
|
||||
}
|
||||
|
||||
// UpByOne applies the next available migration. If there are no migrations to apply, this method
|
||||
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
|
||||
func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) {
|
||||
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.up(ctx, true, math.MaxInt64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -201,21 +206,28 @@ func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) {
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
return res, nil
|
||||
// This should never happen. We should always have exactly one result and test for this.
|
||||
if len(res) > 1 {
|
||||
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// UpTo applies all available migrations up to and including the specified version. If there are no
|
||||
// migrations to apply, this method returns empty list and nil error.
|
||||
// UpTo applies all available migrations up to, and including, the specified version. If there are
|
||||
// no migrations to apply, this method returns empty list and nil error.
|
||||
//
|
||||
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
|
||||
// with a requested version of 10, only versions 9 and 10 will be applied.
|
||||
// with a requested version of 10, only versions 9,10 will be applied.
|
||||
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
if version < 1 {
|
||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
||||
}
|
||||
return p.up(ctx, false, version)
|
||||
}
|
||||
|
||||
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
|
||||
// method returns [ErrNoNextVersion].
|
||||
func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) {
|
||||
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.down(ctx, true, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -223,16 +235,19 @@ func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) {
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
return res, nil
|
||||
if len(res) > 1 {
|
||||
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// DownTo rolls back all migrations down to but not including the specified version.
|
||||
// DownTo rolls back all migrations down to, but not including, the specified version.
|
||||
//
|
||||
// For instance, if the current database version is 11, and the requested version is 9, only
|
||||
// migrations 11 and 10 will be rolled back.
|
||||
// For instance, if the current database version is 11,10,9... and the requested version is 9, only
|
||||
// migrations 11, 10 will be rolled back.
|
||||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
if version < 0 {
|
||||
return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version)
|
||||
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
|
||||
}
|
||||
return p.down(ctx, false, version)
|
||||
}
|
||||
|
@ -10,6 +10,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTablename is the default name of the database table used to track history of applied
|
||||
// migrations. It can be overridden using the [WithTableName] option when creating a new
|
||||
// provider.
|
||||
DefaultTablename = "goose_db_version"
|
||||
)
|
||||
|
||||
@ -85,7 +88,8 @@ type GoMigration struct {
|
||||
// WithGoMigration registers a Go migration with the given version.
|
||||
//
|
||||
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
|
||||
// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set.
|
||||
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
|
||||
// set.
|
||||
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
if version < 1 {
|
||||
@ -122,9 +126,10 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
|
||||
|
||||
// WithAllowMissing allows the provider to apply missing (out-of-order) migrations.
|
||||
//
|
||||
// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is
|
||||
// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied
|
||||
// migrations will be: 1,6,2,3,5.
|
||||
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
|
||||
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
|
||||
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
|
||||
// by new migrations.
|
||||
func WithAllowMissing(b bool) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
c.allowMissing = b
|
||||
@ -132,9 +137,9 @@ func WithAllowMissing(b bool) ProviderOption {
|
||||
})
|
||||
}
|
||||
|
||||
// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations
|
||||
// without tracking the versions in the database schema table. Useful for tests, seeding a database
|
||||
// or running ad-hoc queries.
|
||||
// WithNoVersioning disables versioning. Disabling versioning allows applying migrations without
|
||||
// tracking the versions in the database schema table. Useful for tests, seeding a database or
|
||||
// running ad-hoc queries.
|
||||
func WithNoVersioning(b bool) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
c.noVersioning = b
|
||||
|
@ -41,7 +41,7 @@ func TestProvider(t *testing.T) {
|
||||
// Not parallel because it modifies global state.
|
||||
register := []*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: nil,
|
||||
DownFnContext: nil,
|
||||
},
|
||||
@ -69,32 +69,46 @@ func TestProvider(t *testing.T) {
|
||||
t.Run("duplicate_up", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
||||
},
|
||||
})
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
db := newDB(t)
|
||||
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext")
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
})
|
||||
t.Run("duplicate_down", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
||||
},
|
||||
})
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
db := newDB(t)
|
||||
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext")
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
})
|
||||
t.Run("not_registered", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go",
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration must be registered")
|
||||
})
|
||||
t.Run("zero_not_allowed", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.Migration{
|
||||
{
|
||||
Version: 0,
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration versions must be greater than zero")
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -15,9 +15,153 @@ import (
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingZeroVersion = errors.New("missing zero version migration")
|
||||
)
|
||||
|
||||
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errors.New("version must be greater than zero")
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var apply []*migration
|
||||
if p.cfg.noVersioning {
|
||||
apply = p.migrations
|
||||
} else {
|
||||
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
|
||||
// support the allow missing (out-of-order) feature. For users that don't use this feature, we
|
||||
// could just query the database for the current max version and then apply migrations greater
|
||||
// than that version.
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
apply, err = p.resolveUpMigrations(dbMigrations, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
|
||||
// transaction. The default is to apply each migration sequentially on its own.
|
||||
// https://github.com/pressly/goose/issues/222
|
||||
//
|
||||
// Careful, we can't use a single transaction for all migrations because some may have to be run
|
||||
// in their own transaction.
|
||||
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
|
||||
}
|
||||
|
||||
func (p *Provider) resolveUpMigrations(
|
||||
dbVersions []*sqladapter.ListMigrationsResult,
|
||||
version int64,
|
||||
) ([]*migration, error) {
|
||||
var apply []*migration
|
||||
var dbMaxVersion int64
|
||||
// dbAppliedVersions is a map of all applied migrations in the database.
|
||||
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
|
||||
for _, m := range dbVersions {
|
||||
dbAppliedVersions[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
missingMigrations := findMissingMigrations(dbVersions, p.migrations)
|
||||
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
|
||||
// migrations entirely. At the moment this is not supported, but leaving this comment because
|
||||
// that's where that logic would be handled.
|
||||
//
|
||||
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3.
|
||||
// Not sure if this is a common use case, but it's possible.
|
||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||
var collected []string
|
||||
for _, v := range missingMigrations {
|
||||
collected = append(collected, v.filename)
|
||||
}
|
||||
msg := "migration"
|
||||
if len(collected) > 1 {
|
||||
msg += "s"
|
||||
}
|
||||
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
|
||||
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
||||
)
|
||||
}
|
||||
for _, v := range missingMigrations {
|
||||
m, err := p.getMigration(v.versionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apply = append(apply, m)
|
||||
}
|
||||
// filter all migrations with a version greater than the supplied version (min) and less than or
|
||||
// equal to the requested version (max). Skip any migrations that have already been applied.
|
||||
for _, m := range p.migrations {
|
||||
if dbAppliedVersions[m.Source.Version] {
|
||||
continue
|
||||
}
|
||||
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
|
||||
apply = append(apply, m)
|
||||
}
|
||||
}
|
||||
return apply, nil
|
||||
}
|
||||
|
||||
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if p.cfg.noVersioning {
|
||||
downMigrations := p.migrations
|
||||
if downByOne {
|
||||
last := p.migrations[len(p.migrations)-1]
|
||||
downMigrations = []*migration{last}
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
if dbMigrations[0].Version == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var downMigrations []*migration
|
||||
for _, dbMigration := range dbMigrations {
|
||||
if dbMigration.Version <= version {
|
||||
break
|
||||
}
|
||||
m, err := p.getMigration(dbMigration.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
downMigrations = append(downMigrations, m)
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
||||
|
||||
// runMigrations runs migrations sequentially in the given direction.
|
||||
//
|
||||
// If the migrations slice is empty, this function returns nil with no error.
|
||||
// If the migrations list is empty, return nil without error.
|
||||
func (p *Provider) runMigrations(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
@ -28,28 +172,20 @@ func (p *Provider) runMigrations(
|
||||
if len(migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var apply []*migration
|
||||
apply := migrations
|
||||
if byOne {
|
||||
apply = []*migration{migrations[0]}
|
||||
} else {
|
||||
apply = migrations
|
||||
apply = migrations[:1]
|
||||
}
|
||||
// Lazily parse SQL migrations (if any) in both directions. We do this before running any
|
||||
// migrations so that we can fail fast if there are any errors and avoid leaving the database in
|
||||
// a partially migrated state.
|
||||
|
||||
if err := parseSQL(p.fsys, false, apply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(mf): If we decide to add support for advisory locks at the transaction level, this may
|
||||
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
|
||||
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
|
||||
// to run in a transaction.
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
|
||||
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
|
||||
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
|
||||
// this will deadlock because the Go migration will try to acquire a connection from the pool,
|
||||
@ -57,31 +193,32 @@ func (p *Provider) runMigrations(
|
||||
//
|
||||
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
|
||||
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
|
||||
// an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 {
|
||||
// for _, m := range apply {
|
||||
// if m.IsGo() && !m.Go.UseTx {
|
||||
// return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// an edge case.
|
||||
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
|
||||
for _, m := range apply {
|
||||
switch m.Source.Type {
|
||||
case TypeGo:
|
||||
if m.Go != nil && m.useTx(direction.ToBool()) {
|
||||
return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run migrations individually, opening a new transaction for each migration if the migration is
|
||||
// safe to run in a transaction.
|
||||
|
||||
// Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the
|
||||
// impression that N migrations were applied when in fact some were not 2. Avoid the caller
|
||||
// having to check for nil results
|
||||
// Avoid allocating a slice because we may have a partial migration error.
|
||||
// 1. Avoid giving the impression that N migrations were applied when in fact some were not
|
||||
// 2. Avoid the caller having to check for nil results
|
||||
var results []*MigrationResult
|
||||
for _, m := range apply {
|
||||
current := &MigrationResult{
|
||||
Source: m.Source,
|
||||
Direction: strings.ToLower(direction.String()),
|
||||
// TODO(mf): empty set here
|
||||
Direction: direction.String(),
|
||||
Empty: m.isEmpty(direction.ToBool()),
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
|
||||
// TODO(mf): we should also return the pending migrations here.
|
||||
// TODO(mf): we should also return the pending migrations here, the remaining items in
|
||||
// the apply slice.
|
||||
current.Error = err
|
||||
current.Duration = time.Since(start)
|
||||
return nil, &PartialError{
|
||||
@ -90,16 +227,12 @@ func (p *Provider) runMigrations(
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
current.Duration = time.Since(start)
|
||||
results = append(results, current)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// runIndividually runs an individual migration, opening a new transaction if the migration is safe
|
||||
// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the
|
||||
// supplied connection.
|
||||
func (p *Provider) runIndividually(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
@ -182,8 +315,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
|
||||
cleanup = func() error {
|
||||
p.mu.Unlock()
|
||||
// Use a detached context to unlock the session. This is because the context passed to
|
||||
// SessionLock may have been canceled, and we don't want to cancel the unlock.
|
||||
// TODO(mf): use [context.WithoutCancel] added in go1.21
|
||||
// SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
|
||||
// use [context.WithoutCancel] added in go1.21
|
||||
detachedCtx := context.Background()
|
||||
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
|
||||
}
|
||||
@ -206,7 +339,7 @@ func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
|
||||
for _, m := range migrations {
|
||||
// If the migration is a SQL migration, and it has not been parsed, parse it.
|
||||
if m.Source.Type == TypeSQL && m.SQL == nil {
|
||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug)
|
||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -248,11 +381,14 @@ type missingMigration struct {
|
||||
func findMissingMigrations(
|
||||
dbMigrations []*sqladapter.ListMigrationsResult,
|
||||
fsMigrations []*migration,
|
||||
dbMaxVersion int64,
|
||||
) []missingMigration {
|
||||
existing := make(map[int64]bool)
|
||||
var dbMaxVersion int64
|
||||
for _, m := range dbMigrations {
|
||||
existing[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
var missing []missingMigration
|
||||
for _, m := range fsMigrations {
|
||||
@ -282,10 +418,6 @@ func (p *Provider) getMigration(version int64) (*migration, error) {
|
||||
}
|
||||
|
||||
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errors.New("version must be greater than zero")
|
||||
}
|
||||
|
||||
m, err := p.getMigration(version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -371,5 +503,8 @@ func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
|
||||
if len(res) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
sort.Slice(res, func(i, j int) bool {
|
||||
return res[i].Version > res[j].Version
|
||||
})
|
||||
return res[0].Version, nil
|
||||
}
|
||||
|
@ -1,53 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.cfg.noVersioning {
|
||||
var downMigrations []*migration
|
||||
if downByOne {
|
||||
downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1])
|
||||
} else {
|
||||
downMigrations = p.migrations
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
||||
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dbMigrations[0].Version == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var downMigrations []*migration
|
||||
for _, dbMigration := range dbMigrations {
|
||||
if dbMigration.Version <= version {
|
||||
break
|
||||
}
|
||||
m, err := p.getMigration(dbMigration.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
downMigrations = append(downMigrations, m)
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
@ -53,13 +53,13 @@ func TestProviderRun(t *testing.T) {
|
||||
p, _ := newProviderWithDB(t)
|
||||
_, err := p.UpTo(context.Background(), 0)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "version must be greater than zero")
|
||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
||||
_, err = p.DownTo(context.Background(), -1)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1")
|
||||
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
||||
_, err = p.ApplyVersion(context.Background(), 0, true)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "version must be greater than zero")
|
||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
||||
})
|
||||
t.Run("up_and_down_all", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@ -77,24 +77,24 @@ func TestProviderRun(t *testing.T) {
|
||||
res, err := p.Up(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), numCount)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
|
||||
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up")
|
||||
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up")
|
||||
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up")
|
||||
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up")
|
||||
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up")
|
||||
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up")
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
||||
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
|
||||
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
|
||||
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
|
||||
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
||||
// Test Down
|
||||
res, err = p.DownTo(ctx, 0)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), numCount)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down")
|
||||
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down")
|
||||
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down")
|
||||
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down")
|
||||
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down")
|
||||
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down")
|
||||
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down")
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
||||
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
||||
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
||||
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
|
||||
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
|
||||
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
|
||||
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
|
||||
})
|
||||
t.Run("up_and_down_by_one", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@ -112,8 +112,8 @@ func TestProviderRun(t *testing.T) {
|
||||
break
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
check.Number(t, res[0].Source.Version, int64(counter))
|
||||
check.NotNil(t, res)
|
||||
check.Number(t, res.Source.Version, int64(counter))
|
||||
}
|
||||
currentVersion, err := p.GetDBVersion(ctx)
|
||||
check.NoError(t, err)
|
||||
@ -131,8 +131,8 @@ func TestProviderRun(t *testing.T) {
|
||||
break
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1))
|
||||
check.NotNil(t, res)
|
||||
check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
|
||||
}
|
||||
// Once everything is tested the version should match the highest testdata version
|
||||
currentVersion, err = p.GetDBVersion(ctx)
|
||||
@ -148,8 +148,8 @@ func TestProviderRun(t *testing.T) {
|
||||
results, err := p.UpTo(ctx, upToVersion)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(results), upToVersion)
|
||||
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
|
||||
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up")
|
||||
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
// Fetch the goose version from DB
|
||||
currentVersion, err := p.GetDBVersion(ctx)
|
||||
check.NoError(t, err)
|
||||
@ -237,7 +237,11 @@ func TestProviderRun(t *testing.T) {
|
||||
res, err := p.ApplyVersion(ctx, s.Version, true)
|
||||
check.NoError(t, err)
|
||||
// Round-trip the migration result through the database to ensure it's valid.
|
||||
assertResult(t, res, s, "up")
|
||||
var empty bool
|
||||
if s.Version == 6 || s.Version == 7 {
|
||||
empty = true
|
||||
}
|
||||
assertResult(t, res, s, "up", empty)
|
||||
}
|
||||
// Apply all migrations in the down direction.
|
||||
for i := len(sources) - 1; i >= 0; i-- {
|
||||
@ -245,7 +249,11 @@ func TestProviderRun(t *testing.T) {
|
||||
res, err := p.ApplyVersion(ctx, s.Version, false)
|
||||
check.NoError(t, err)
|
||||
// Round-trip the migration result through the database to ensure it's valid.
|
||||
assertResult(t, res, s, "down")
|
||||
var empty bool
|
||||
if s.Version == 6 || s.Version == 7 {
|
||||
empty = true
|
||||
}
|
||||
assertResult(t, res, s, "down", empty)
|
||||
}
|
||||
// Try apply version 1 multiple times
|
||||
_, err := p.ApplyVersion(ctx, 1, true)
|
||||
@ -324,7 +332,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||
// Check Results field
|
||||
check.Number(t, len(expected.Applied), 1)
|
||||
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up")
|
||||
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
// Check Failed field
|
||||
check.Bool(t, expected.Failed != nil, true)
|
||||
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
|
||||
@ -368,11 +376,11 @@ func TestConcurrentProvider(t *testing.T) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(res) != 1 {
|
||||
t.Errorf("expected 1 result, got %d", len(res))
|
||||
if res == nil {
|
||||
t.Errorf("expected non-nil result, got nil")
|
||||
return
|
||||
}
|
||||
ch <- res[0].Source.Version
|
||||
ch <- res.Source.Version
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
@ -623,13 +631,13 @@ func TestAllowMissing(t *testing.T) {
|
||||
// 4
|
||||
upResult, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(upResult), 1)
|
||||
check.Number(t, upResult[0].Source.Version, 4)
|
||||
check.NotNil(t, upResult)
|
||||
check.Number(t, upResult.Source.Version, 4)
|
||||
// 6
|
||||
upResult, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(upResult), 1)
|
||||
check.Number(t, upResult[0].Source.Version, 6)
|
||||
check.NotNil(t, upResult)
|
||||
check.Number(t, upResult.Source.Version, 6)
|
||||
|
||||
count, err := getGooseVersionCount(db, provider.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
@ -645,21 +653,29 @@ func TestAllowMissing(t *testing.T) {
|
||||
// So migrating down should be the reverse of the applied order:
|
||||
// 6,4,5,3,2,1
|
||||
|
||||
expected := []int64{6, 4, 5, 3, 2, 1}
|
||||
for i, v := range expected {
|
||||
// TODO(mf): this is returning it by the order it was applied.
|
||||
current, err := p.GetDBVersion(ctx)
|
||||
testDownAndVersion := func(wantDBVersion, wantResultVersion int64) {
|
||||
currentVersion, err := p.GetDBVersion(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, current, v)
|
||||
downResult, err := p.Down(ctx)
|
||||
if i == len(expected)-1 {
|
||||
check.HasError(t, provider.ErrVersionNotFound)
|
||||
} else {
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(downResult), 1)
|
||||
check.Number(t, downResult[0].Source.Version, v)
|
||||
}
|
||||
check.Number(t, currentVersion, wantDBVersion)
|
||||
downRes, err := p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, downRes)
|
||||
check.Number(t, downRes.Source.Version, wantResultVersion)
|
||||
}
|
||||
|
||||
// This behaviour may need to change, see the following issues for more details:
|
||||
// - https://github.com/pressly/goose/issues/523
|
||||
// - https://github.com/pressly/goose/issues/402
|
||||
|
||||
testDownAndVersion(6, 6)
|
||||
testDownAndVersion(5, 4) // Ensure the max db version is 5 before down.
|
||||
testDownAndVersion(5, 5)
|
||||
testDownAndVersion(3, 3)
|
||||
testDownAndVersion(2, 2)
|
||||
testDownAndVersion(1, 1)
|
||||
_, err = p.Down(ctx)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
|
||||
})
|
||||
}
|
||||
|
||||
@ -674,7 +690,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
|
||||
}
|
||||
|
||||
func TestGoOnly(t *testing.T) {
|
||||
// Not parallel because it modifies global state.
|
||||
// Not parallel because each subtest modifies global state.
|
||||
|
||||
countUser := func(db *sql.DB) int {
|
||||
q := `SELECT count(*)FROM users`
|
||||
@ -688,7 +704,7 @@ func TestGoOnly(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
register := []*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||
DownFnContext: newTxFn("DROP TABLE users"),
|
||||
},
|
||||
@ -713,27 +729,23 @@ func TestGoOnly(t *testing.T) {
|
||||
// Apply migration 1
|
||||
res, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
check.Bool(t, tableExists(t, db, "users"), true)
|
||||
// Apply migration 2
|
||||
res, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
|
||||
check.Number(t, countUser(db), 3)
|
||||
// Rollback migration 2
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
// Rollback migration 1
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
||||
// Check table does not exist
|
||||
check.Bool(t, tableExists(t, db, "users"), false)
|
||||
})
|
||||
@ -741,7 +753,7 @@ func TestGoOnly(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
register := []*provider.Migration{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true,
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||
DownFnNoTxContext: newDBFn("DROP TABLE users"),
|
||||
},
|
||||
@ -766,27 +778,23 @@ func TestGoOnly(t *testing.T) {
|
||||
// Apply migration 1
|
||||
res, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
check.Bool(t, tableExists(t, db, "users"), true)
|
||||
// Apply migration 2
|
||||
res, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
|
||||
check.Number(t, countUser(db), 3)
|
||||
// Rollback migration 2
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
// Rollback migration 1
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 1)
|
||||
assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
|
||||
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
||||
// Check table does not exist
|
||||
check.Bool(t, tableExists(t, db, "users"), false)
|
||||
})
|
||||
@ -880,7 +888,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
)
|
||||
g.Go(func() error {
|
||||
for {
|
||||
results, err := provider1.UpByOne(context.Background())
|
||||
result, err := provider1.UpByOne(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
return nil
|
||||
@ -888,17 +896,15 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
if len(results) != 1 {
|
||||
return fmt.Errorf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
check.NotNil(t, result)
|
||||
mu.Lock()
|
||||
applied = append(applied, results[0].Source.Version)
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
}
|
||||
})
|
||||
g.Go(func() error {
|
||||
for {
|
||||
results, err := provider2.UpByOne(context.Background())
|
||||
result, err := provider2.UpByOne(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
return nil
|
||||
@ -906,11 +912,9 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
if len(results) != 1 {
|
||||
return fmt.Errorf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
check.NotNil(t, result)
|
||||
mu.Lock()
|
||||
applied = append(applied, results[0].Source.Version)
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
}
|
||||
})
|
||||
@ -986,37 +990,33 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
)
|
||||
g.Go(func() error {
|
||||
for {
|
||||
results, err := provider1.Down(context.Background())
|
||||
result, err := provider1.Down(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if len(results) != 1 {
|
||||
return fmt.Errorf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
mu.Lock()
|
||||
applied = append(applied, results[0].Source.Version)
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
}
|
||||
})
|
||||
g.Go(func() error {
|
||||
for {
|
||||
results, err := provider2.Down(context.Background())
|
||||
result, err := provider2.Down(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if len(results) != 1 {
|
||||
return fmt.Errorf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
mu.Lock()
|
||||
applied = append(applied, results[0].Source.Version)
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
}
|
||||
})
|
||||
@ -1124,11 +1124,12 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St
|
||||
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
|
||||
}
|
||||
|
||||
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) {
|
||||
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
|
||||
t.Helper()
|
||||
check.NotNil(t, got)
|
||||
check.Equal(t, got.Source, source)
|
||||
check.Equal(t, got.Direction, direction)
|
||||
check.Equal(t, got.Empty, false)
|
||||
check.Equal(t, got.Empty, isEmpty)
|
||||
check.Bool(t, got.Error == nil, true)
|
||||
check.Bool(t, got.Duration > 0, true)
|
||||
}
|
||||
@ -1136,7 +1137,7 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
|
||||
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.Type, typ)
|
||||
check.Equal(t, got.Fullpath, name)
|
||||
check.Equal(t, got.Path, name)
|
||||
check.Equal(t, got.Version, version)
|
||||
switch got.Type {
|
||||
case provider.TypeGo:
|
||||
|
@ -1,96 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errors.New("version must be greater than zero")
|
||||
}
|
||||
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if p.cfg.noVersioning {
|
||||
// Short circuit if versioning is disabled and apply all migrations.
|
||||
return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne)
|
||||
}
|
||||
|
||||
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
|
||||
// support the out-of-order (allow missing) feature. For users who don't use this feature, we
|
||||
// could just query the database for the current version and then apply migrations that are
|
||||
// greater than that version.
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbMaxVersion := dbMigrations[0].Version
|
||||
// lookupAppliedInDB is a map of all applied migrations in the database.
|
||||
lookupAppliedInDB := make(map[int64]bool)
|
||||
for _, m := range dbMigrations {
|
||||
lookupAppliedInDB[m.Version] = true
|
||||
}
|
||||
|
||||
missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion)
|
||||
|
||||
// feature(mf): It is very possible someone may want to apply ONLY new migrations and skip
|
||||
// missing migrations entirely. At the moment this is not supported, but leaving this comment
|
||||
// because that's where that logic will be handled.
|
||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||
var collected []string
|
||||
for _, v := range missingMigrations {
|
||||
collected = append(collected, v.filename)
|
||||
}
|
||||
msg := "migration"
|
||||
if len(collected) > 1 {
|
||||
msg += "s"
|
||||
}
|
||||
return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]",
|
||||
len(missingMigrations), msg, strings.Join(collected, ","))
|
||||
}
|
||||
|
||||
var migrationsToApply []*migration
|
||||
if p.cfg.allowMissing {
|
||||
for _, v := range missingMigrations {
|
||||
m, err := p.getMigration(v.versionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrationsToApply = append(migrationsToApply, m)
|
||||
}
|
||||
}
|
||||
// filter all migrations with a version greater than the supplied version (min) and less than or
|
||||
// equal to the requested version (max).
|
||||
for _, m := range p.migrations {
|
||||
if lookupAppliedInDB[m.Source.Version] {
|
||||
continue
|
||||
}
|
||||
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
|
||||
migrationsToApply = append(migrationsToApply, m)
|
||||
}
|
||||
}
|
||||
|
||||
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
|
||||
// transaction. The default is to apply each migration sequentially on its own.
|
||||
// https://github.com/pressly/goose/issues/222
|
||||
//
|
||||
// Note, we can't use a single transaction for all migrations because some may have to be run in
|
||||
// their own transaction.
|
||||
|
||||
return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne)
|
||||
}
|
@ -41,36 +41,23 @@ func (t MigrationType) String() string {
|
||||
|
||||
// Source represents a single migration source.
|
||||
//
|
||||
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
|
||||
// the migration has a corresponding file on disk. It will be empty if the migration was registered
|
||||
// manually.
|
||||
// The Path field may be empty if the migration was registered manually. This is typically the case
|
||||
// for Go migrations registered using the [WithGoMigration] option.
|
||||
type Source struct {
|
||||
// Type is the type of migration.
|
||||
Type MigrationType
|
||||
// Full path to the migration file.
|
||||
//
|
||||
// Example: /path/to/migrations/001_create_users_table.sql
|
||||
Fullpath string
|
||||
// Version is the version of the migration.
|
||||
Type MigrationType
|
||||
Path string
|
||||
Version int64
|
||||
}
|
||||
|
||||
// MigrationResult is the result of a single migration operation.
|
||||
//
|
||||
// Note, the caller is responsible for checking the Error field for any errors that occurred while
|
||||
// running the migration. If the Error field is not nil, the migration failed.
|
||||
type MigrationResult struct {
|
||||
Source Source
|
||||
Duration time.Duration
|
||||
Direction string
|
||||
// Empty is true if the file was valid, but no statements to apply. These are still versioned
|
||||
// migrations, but typically have no effect on the database.
|
||||
//
|
||||
// For SQL migrations, this means there was a valid .sql file but contained no statements. For
|
||||
// Go migrations, this means the function was nil.
|
||||
// Empty indicates no action was taken during the migration, but it was still versioned. For
|
||||
// SQL, it means no statements; for Go, it's a nil function.
|
||||
Empty bool
|
||||
|
||||
// Error is any error that occurred while running the migration.
|
||||
// Error is only set if the migration failed.
|
||||
Error error
|
||||
}
|
||||
|
||||
@ -78,22 +65,20 @@ type MigrationResult struct {
|
||||
type State string
|
||||
|
||||
const (
|
||||
// StatePending represents a migration that is on the filesystem, but not in the database.
|
||||
// StatePending is a migration that exists on the filesystem, but not in the database.
|
||||
StatePending State = "pending"
|
||||
// StateApplied represents a migration that is in BOTH the database and on the filesystem.
|
||||
// StateApplied is a migration that has been applied to the database and exists on the
|
||||
// filesystem.
|
||||
StateApplied State = "applied"
|
||||
|
||||
// StateUntracked represents a migration that is in the database, but not on the filesystem.
|
||||
// StateUntracked State = "untracked"
|
||||
// TODO(mf): we could also add a third state for untracked migrations. This would be useful for
|
||||
// migrations that were manually applied to the database, but not versioned. Or the Source was
|
||||
// deleted, but the migration still exists in the database. StateUntracked State = "untracked"
|
||||
)
|
||||
|
||||
// MigrationStatus represents the status of a single migration.
|
||||
type MigrationStatus struct {
|
||||
// State is the state of the migration.
|
||||
State State
|
||||
// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
|
||||
// [StateUntracked].
|
||||
Source Source
|
||||
State State
|
||||
AppliedAt time.Time
|
||||
// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
|
||||
Source Source
|
||||
}
|
||||
|
@ -61,6 +61,25 @@ func TestStore(t *testing.T) {
|
||||
check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
|
||||
})
|
||||
})
|
||||
t.Run("ListMigrations", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||
check.NoError(t, err)
|
||||
store, err := sqladapter.NewStore("sqlite3", "foo")
|
||||
check.NoError(t, err)
|
||||
err = store.CreateVersionTable(context.Background(), db)
|
||||
check.NoError(t, err)
|
||||
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1))
|
||||
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3))
|
||||
check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2))
|
||||
res, err := store.ListMigrations(context.Background(), db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), 3)
|
||||
// Check versions are in descending order: [2, 3, 1]
|
||||
check.Number(t, res[0].Version, 2)
|
||||
check.Number(t, res[1].Version, 3)
|
||||
check.Number(t, res[2].Version, 1)
|
||||
})
|
||||
}
|
||||
|
||||
// testStore tests various store operations.
|
||||
|
Loading…
x
Reference in New Issue
Block a user