Merge pull request #8 from pressly/fix_go_up

Fix go migration up
pull/10/head
Vojtech Vitek 2016-09-29 19:41:40 -04:00 committed by GitHub
commit aad3e6a24e
7 changed files with 88 additions and 144 deletions

18
down.go
View File

@ -1,9 +1,6 @@
package goose
import (
"database/sql"
"fmt"
)
import "database/sql"
func Down(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
@ -11,15 +8,14 @@ func Down(db *sql.DB, dir string) error {
return err
}
previous, err := GetPreviousDBVersion(dir, current)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
if err != nil {
if err == ErrNoPreviousVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
}
return err
}
return err
}
migrations.Sort(false) // descending, Next will be Previous
previous, err := migrations.Next(current)
if err != nil {
return err
}

View File

@ -18,10 +18,9 @@ func checkVersionDuplicates(dir string) error {
return err
}
// Try sorting all migrations, so we get panic on any duplicates.
ms := migrationSorter(migrations)
ms.Sort(true)
ms.Sort(false)
// try both directions
migrations.Sort(false)
migrations.Sort(true)
return nil
}

View File

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"sort"
@ -18,7 +17,10 @@ import (
var (
ErrNoPreviousVersion = errors.New("no previous version found")
ErrNoNextVersion = errors.New("no next version found")
goMigrations []*Migration
MaxVersion = 9223372036854775807 // max(int64)
goMigrations []*Migration
)
type MigrationRecord struct {
@ -36,18 +38,59 @@ type Migration struct {
Down func(*sql.Tx) error // Down go migration function
}
type migrationSorter []*Migration
type Migrations []*Migration
// helpers so we can use pkg sort
func (ms migrationSorter) Len() int { return len(ms) }
func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms migrationSorter) Less(i, j int) bool {
func (ms Migrations) Len() int { return len(ms) }
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms Migrations) Less(i, j int) bool {
if ms[i].Version == ms[j].Version {
log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source)
}
return ms[i].Version < ms[j].Version
}
func (ms Migrations) Sort(up bool) {
// sort ascending or descending by version
if up {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}
func (ms Migrations) Last() (int64, error) {
if len(ms) == 0 {
return -1, ErrNoNextVersion
}
return ms[len(ms)-1].Version, nil
}
func (ms Migrations) Next(current int64) (int64, error) {
exceptLast := ms[:len(ms)-1]
for i, migration := range exceptLast {
if migration.Version == current {
return ms[i+1].Version, nil
}
}
return -1, ErrNoNextVersion
}
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
v, _ := NumericComponent(filename)
@ -77,7 +120,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
return nil
}
ms := migrationSorter(migrations)
ms := Migrations(migrations)
direction := current < target
ms.Sort(direction)
@ -122,7 +165,7 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
// collect all the valid looking migration scripts in the
// migrations folder and go func registry, and key them by version
func CollectMigrations(dirpath string, current, target int64) (m []*Migration, err error) {
func CollectMigrations(dirpath string, current, target int64) (m Migrations, err error) {
// extract the numeric component of each migration,
// filter out any uninteresting files,
@ -169,27 +212,6 @@ func versionFilter(v, current, target int64) bool {
return false
}
func (ms migrationSorter) Sort(direction bool) {
// sort ascending or descending by version
if direction {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}
// look for migration scripts with names in the form:
// XXX_descriptivename.ext
// where XXX specifies the version number
@ -298,95 +320,6 @@ func GetDBVersion(db *sql.DB) (int64, error) {
return version, nil
}
func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) {
previous = -1
sawGivenVersion := false
filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > previous && v < version {
previous = v
}
if v == version {
sawGivenVersion = true
}
}
}
return nil
})
if previous == -1 {
if sawGivenVersion {
// the given version is (likely) valid but we didn't find
// anything before it.
// 'previous' must reflect that no migrations have been applied.
previous = 0
} else {
err = ErrNoPreviousVersion
}
}
return
}
func GetNextDBVersion(dirpath string, version int64) (next int64, err error) {
next = 9223372036854775807 // max(int64)
filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v < next && v > version {
next = v
}
}
}
return nil
})
if next == 9223372036854775807 {
next = version
err = ErrNoNextVersion
}
return
}
// helper to identify the most recent possible version
// within a folder of migration scripts
func GetMostRecentDBVersion(dirpath string) (version int64, err error) {
version = -1
filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if walkerr != nil {
return walkerr
}
if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > version {
version = v
}
}
}
return nil
})
if version == -1 {
err = errors.New("no valid version found")
}
return
}
func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
if migrationType != "go" && migrationType != "sql" {

View File

@ -10,7 +10,7 @@ func newMigration(v int64, src string) *Migration {
func TestMigrationMapSortUp(t *testing.T) {
ms := migrationSorter{}
ms := Migrations{}
// insert in any order
ms = append(ms, newMigration(20120000, "test"))
@ -27,7 +27,7 @@ func TestMigrationMapSortUp(t *testing.T) {
func TestMigrationMapSortDown(t *testing.T) {
ms := migrationSorter{}
ms := Migrations{}
// insert in any order
ms = append(ms, newMigration(20120000, "test"))
@ -42,7 +42,7 @@ func TestMigrationMapSortDown(t *testing.T) {
validateMigrationSort(t, ms, sorted)
}
func validateMigrationSort(t *testing.T, ms migrationSorter, sorted []int64) {
func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) {
for i, m := range ms {
if sorted[i] != m.Version {

View File

@ -10,7 +10,13 @@ func Redo(db *sql.DB, dir string) error {
return err
}
previous, err := GetPreviousDBVersion(dir, current)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(false) // descending, Next will be Previous
previous, err := migrations.Next(current)
if err != nil {
return err
}

View File

@ -14,9 +14,7 @@ func Status(db *sql.DB, dir string) error {
if err != nil {
return err
}
ms := migrationSorter(migrations)
ms.Sort(true)
migrations.Sort(true)
// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
@ -26,8 +24,8 @@ func Status(db *sql.DB, dir string) error {
fmt.Println("goose: status")
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, m := range ms {
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
for _, migration := range migrations {
printMigrationStatus(db, migration.Version, filepath.Base(migration.Source))
}
return nil

16
up.go
View File

@ -6,7 +6,13 @@ import (
)
func Up(db *sql.DB, dir string) error {
target, err := GetMostRecentDBVersion(dir)
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)
target, err := migrations.Last()
if err != nil {
return err
}
@ -18,12 +24,18 @@ func Up(db *sql.DB, dir string) error {
}
func UpByOne(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)
current, err := GetDBVersion(db)
if err != nil {
return err
}
next, err := GetNextDBVersion(dir, current)
next, err := migrations.Next(current)
if err != nil {
if err == ErrNoNextVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)