diff --git a/down.go b/down.go index 3b5835c..6315623 100644 --- a/down.go +++ b/down.go @@ -1,9 +1,6 @@ package goose -import ( - "database/sql" - "fmt" -) +import "database/sql" func Down(db *sql.DB, dir string) error { current, err := GetDBVersion(db) @@ -11,15 +8,14 @@ func Down(db *sql.DB, dir string) error { return err } - previous, err := GetPreviousDBVersion(dir, current) + migrations, err := CollectMigrations(dir, minVersion, maxVersion) if err != nil { - if err != nil { - if err == ErrNoPreviousVersion { - fmt.Printf("goose: no migrations to run. current version: %d\n", current) - } - return err - } + return err + } + migrations.Sort(false) // descending, Next will be Previous + previous, err := migrations.Next(current) + if err != nil { return err } diff --git a/goose.go b/goose.go index 6dc0714..811d416 100644 --- a/goose.go +++ b/goose.go @@ -18,10 +18,9 @@ func checkVersionDuplicates(dir string) error { return err } - // Try sorting all migrations, so we get panic on any duplicates. - ms := migrationSorter(migrations) - ms.Sort(true) - ms.Sort(false) + // try both directions + migrations.Sort(false) + migrations.Sort(true) return nil } diff --git a/migrate.go b/migrate.go index 9478162..a1fcc73 100644 --- a/migrate.go +++ b/migrate.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log" - "os" "path/filepath" "runtime" "sort" @@ -18,7 +17,10 @@ import ( var ( ErrNoPreviousVersion = errors.New("no previous version found") ErrNoNextVersion = errors.New("no next version found") - goMigrations []*Migration + + MaxVersion = 9223372036854775807 // max(int64) + + goMigrations []*Migration ) type MigrationRecord struct { @@ -36,18 +38,59 @@ type Migration struct { Down func(*sql.Tx) error // Down go migration function } -type migrationSorter []*Migration +type Migrations []*Migration // helpers so we can use pkg sort -func (ms migrationSorter) Len() int { return len(ms) } -func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] } -func (ms migrationSorter) Less(i, j int) bool { +func (ms Migrations) Len() int { return len(ms) } +func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] } +func (ms Migrations) Less(i, j int) bool { if ms[i].Version == ms[j].Version { log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source) } return ms[i].Version < ms[j].Version } +func (ms Migrations) Sort(up bool) { + + // sort ascending or descending by version + if up { + sort.Sort(ms) + } else { + sort.Sort(sort.Reverse(ms)) + } + + // now that we're sorted in the appropriate direction, + // populate next and previous for each migration + for i, m := range ms { + prev := int64(-1) + if i > 0 { + prev = ms[i-1].Version + ms[i-1].Next = m.Version + } + ms[i].Previous = prev + } +} + +func (ms Migrations) Last() (int64, error) { + if len(ms) == 0 { + return -1, ErrNoNextVersion + } + + return ms[len(ms)-1].Version, nil +} + +func (ms Migrations) Next(current int64) (int64, error) { + exceptLast := ms[:len(ms)-1] + + for i, migration := range exceptLast { + if migration.Version == current { + return ms[i+1].Version, nil + } + } + + return -1, ErrNoNextVersion +} + func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { _, filename, _, _ := runtime.Caller(1) v, _ := NumericComponent(filename) @@ -77,7 +120,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) { return nil } - ms := migrationSorter(migrations) + ms := Migrations(migrations) direction := current < target ms.Sort(direction) @@ -122,7 +165,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) { // collect all the valid looking migration scripts in the // migrations folder and go func registry, and key them by version -func CollectMigrations(dirpath string, current, target int64) (m []*Migration, err error) { +func CollectMigrations(dirpath string, current, target int64) (m Migrations, err error) { // extract the numeric component of each migration, // filter out any uninteresting files, @@ -169,27 +212,6 @@ func versionFilter(v, current, target int64) bool { return false } -func (ms migrationSorter) Sort(direction bool) { - - // sort ascending or descending by version - if direction { - sort.Sort(ms) - } else { - sort.Sort(sort.Reverse(ms)) - } - - // now that we're sorted in the appropriate direction, - // populate next and previous for each migration - for i, m := range ms { - prev := int64(-1) - if i > 0 { - prev = ms[i-1].Version - ms[i-1].Next = m.Version - } - ms[i].Previous = prev - } -} - // look for migration scripts with names in the form: // XXX_descriptivename.ext // where XXX specifies the version number @@ -298,95 +320,6 @@ func GetDBVersion(db *sql.DB) (int64, error) { return version, nil } -func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) { - - previous = -1 - sawGivenVersion := false - - filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { - - if !info.IsDir() { - if v, e := NumericComponent(name); e == nil { - if v > previous && v < version { - previous = v - } - if v == version { - sawGivenVersion = true - } - } - } - - return nil - }) - - if previous == -1 { - if sawGivenVersion { - // the given version is (likely) valid but we didn't find - // anything before it. - // 'previous' must reflect that no migrations have been applied. - previous = 0 - } else { - err = ErrNoPreviousVersion - } - } - - return -} - -func GetNextDBVersion(dirpath string, version int64) (next int64, err error) { - - next = 9223372036854775807 // max(int64) - - filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { - - if !info.IsDir() { - if v, e := NumericComponent(name); e == nil { - if v < next && v > version { - next = v - } - } - } - - return nil - }) - - if next == 9223372036854775807 { - next = version - err = ErrNoNextVersion - } - - return -} - -// helper to identify the most recent possible version -// within a folder of migration scripts -func GetMostRecentDBVersion(dirpath string) (version int64, err error) { - - version = -1 - - filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { - if walkerr != nil { - return walkerr - } - - if !info.IsDir() { - if v, e := NumericComponent(name); e == nil { - if v > version { - version = v - } - } - } - - return nil - }) - - if version == -1 { - err = errors.New("no valid version found") - } - - return -} - func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { if migrationType != "go" && migrationType != "sql" { diff --git a/migrate_test.go b/migrate_test.go index f0ef2c1..463c4c0 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -10,7 +10,7 @@ func newMigration(v int64, src string) *Migration { func TestMigrationMapSortUp(t *testing.T) { - ms := migrationSorter{} + ms := Migrations{} // insert in any order ms = append(ms, newMigration(20120000, "test")) @@ -27,7 +27,7 @@ func TestMigrationMapSortUp(t *testing.T) { func TestMigrationMapSortDown(t *testing.T) { - ms := migrationSorter{} + ms := Migrations{} // insert in any order ms = append(ms, newMigration(20120000, "test")) @@ -42,7 +42,7 @@ func TestMigrationMapSortDown(t *testing.T) { validateMigrationSort(t, ms, sorted) } -func validateMigrationSort(t *testing.T, ms migrationSorter, sorted []int64) { +func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { for i, m := range ms { if sorted[i] != m.Version { diff --git a/redo.go b/redo.go index 4b34117..6a3a204 100644 --- a/redo.go +++ b/redo.go @@ -10,7 +10,13 @@ func Redo(db *sql.DB, dir string) error { return err } - previous, err := GetPreviousDBVersion(dir, current) + migrations, err := CollectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + migrations.Sort(false) // descending, Next will be Previous + + previous, err := migrations.Next(current) if err != nil { return err } diff --git a/status.go b/status.go index 8138f97..9a8fbbc 100644 --- a/status.go +++ b/status.go @@ -14,9 +14,7 @@ func Status(db *sql.DB, dir string) error { if err != nil { return err } - - ms := migrationSorter(migrations) - ms.Sort(true) + migrations.Sort(true) // must ensure that the version table exists if we're running on a pristine DB if _, err := EnsureDBVersion(db); err != nil { @@ -26,8 +24,8 @@ func Status(db *sql.DB, dir string) error { fmt.Println("goose: status") fmt.Println(" Applied At Migration") fmt.Println(" =======================================") - for _, m := range ms { - printMigrationStatus(db, m.Version, filepath.Base(m.Source)) + for _, migration := range migrations { + printMigrationStatus(db, migration.Version, filepath.Base(migration.Source)) } return nil diff --git a/up.go b/up.go index f1b348d..9ff520d 100644 --- a/up.go +++ b/up.go @@ -6,7 +6,13 @@ import ( ) func Up(db *sql.DB, dir string) error { - target, err := GetMostRecentDBVersion(dir) + migrations, err := CollectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + migrations.Sort(true) + + target, err := migrations.Last() if err != nil { return err } @@ -18,12 +24,18 @@ func Up(db *sql.DB, dir string) error { } func UpByOne(db *sql.DB, dir string) error { + migrations, err := CollectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + migrations.Sort(true) + current, err := GetDBVersion(db) if err != nil { return err } - next, err := GetNextDBVersion(dir, current) + next, err := migrations.Next(current) if err != nil { if err == ErrNoNextVersion { fmt.Printf("goose: no migrations to run. current version: %d\n", current)