diff --git a/README.md b/README.md index 88ca418..c6b6008 100644 --- a/README.md +++ b/README.md @@ -74,4 +74,4 @@ A sample dbconf.yml looks like Here, `development` specifies the name of the configuration, and the `driver` and `open` elements are passed directly to database/sql to access the specified database. -You may include as many configurations as you like, and you can use the `--config` command line option to specify which one to use. goose defaults to using a configuration called `development`. \ No newline at end of file +You may include as many configurations as you like, and you can use the `--config` command line option to specify which one to use. goose defaults to using a configuration called `development`. diff --git a/cmd/goose.go b/cmd/goose.go index 98243a1..3d22a94 100644 --- a/cmd/goose.go +++ b/cmd/goose.go @@ -27,10 +27,16 @@ type DBVersion struct { TStamp time.Time } +type Migration struct { + Next int // next version, or -1 if none + Previous int // previous version, -1 if none + Source string // .go or .sql script +} + type MigrationMap struct { - Versions []int // sorted slice of version keys - Sources map[int]string // sources (.sql or .go) keyed by version - Direction bool // sort direction: true -> Up, false -> Down + Versions []int // sorted slice of version keys + Migrations map[int]Migration // sources (.sql or .go) keyed by version + Direction bool // sort direction: true -> Up, false -> Down } var dbFolder = flag.String("db", "db", "folder containing db info") @@ -53,7 +59,7 @@ func dbConfFromFile(path, envtype string) (*DBConf, error) { f, err := yaml.ReadFile(path) if err != nil { - log.Fatal(err) + return nil, err } drv, derr := f.Get(fmt.Sprintf("%s.driver", envtype)) @@ -85,12 +91,12 @@ func runMigrations(conf *DBConf, target int) { log.Fatal("couldn't get/set DB version") } - migrations, err := collectMigrations(path.Join(*dbFolder, "migrations"), currentVersion) + mm, err := collectMigrations(path.Join(*dbFolder, "migrations"), currentVersion) if err != nil { log.Fatal(err) } - if len(migrations.Versions) == 0 { + if len(mm.Versions) == 0 { fmt.Printf("goose: no migrations to run. current version: %d\n", currentVersion) return } @@ -98,18 +104,18 @@ func runMigrations(conf *DBConf, target int) { fmt.Printf("goose: migrating db configuration '%v', current version: %d, target: %d\n", conf.Name, currentVersion, *targetVersion) - for _, v := range migrations.Versions { + for _, v := range mm.Versions { var numStatements int var e error - filepath := migrations.Sources[v] + filepath := mm.Migrations[v].Source switch path.Ext(filepath) { case ".go": - numStatements, e = runGoMigration(conf, filepath, v, migrations.Direction) + numStatements, e = runGoMigration(conf, filepath, v, mm.Direction) case ".sql": - numStatements, e = runSQLMigration(db, filepath, v, migrations.Direction) + numStatements, e = runSQLMigration(db, filepath, v, mm.Direction) } if e != nil { @@ -135,7 +141,7 @@ func collectMigrations(dirpath string, currentVersion int) (mm *MigrationMap, er } mm = &MigrationMap{ - Sources: make(map[int]string), + Migrations: make(map[int]Migration), } // extract the numeric component of each migration, @@ -153,7 +159,7 @@ func collectMigrations(dirpath string, currentVersion int) (mm *MigrationMap, er continue } - if _, ok := mm.Sources[v]; ok { + if _, ok := mm.Migrations[v]; ok { log.Fatalf("more than one file specifies the migration for version %d (%s and %s)", v, mm.Versions[v], path.Join(dirpath, name)) } @@ -190,7 +196,11 @@ func versionFilter(v, current, target int) bool { func (m *MigrationMap) Append(v int, source string) { m.Versions = append(m.Versions, v) - m.Sources[v] = source + m.Migrations[v] = Migration{ + Next: -1, + Previous: -1, + Source: source, + } } func (m *MigrationMap) Sort(currentVersion int) { @@ -208,6 +218,24 @@ func (m *MigrationMap) Sort(currentVersion int) { m.Versions[i], m.Versions[j] = m.Versions[j], m.Versions[i] } } + + // now that we're sorted in the appropriate direction, + // populate next and previous for each migration + // + // work around http://code.google.com/p/go/issues/detail?id=3117 + previousV := -1 + for _, v := range m.Versions { + cur := m.Migrations[v] + cur.Previous = previousV + + // if a migration exists at prev, its next is now v + if prev, ok := m.Migrations[previousV]; ok { + prev.Next = v + m.Migrations[previousV] = prev + } + + previousV = v + } } // look for migration scripts with names in the form: @@ -247,7 +275,7 @@ func ensureDBVersion(db *sql.DB) (int, error) { );` insert := "INSERT INTO goose_db_version (version_id) VALUES (0);" - for _, str := range [2]string{create, insert} { + for _, str := range []string{create, insert} { if _, err := txn.Exec(str); err != nil { txn.Rollback() return 0, err