migrations: capture next/previous versions - can't assume migration IDs will be sequential

This commit is contained in:
Liam Staskawicz 2012-12-11 09:24:51 -08:00
parent 8394b86b46
commit b373fe57bd
2 changed files with 43 additions and 15 deletions

View File

@ -74,4 +74,4 @@ A sample dbconf.yml looks like
Here, `development` specifies the name of the configuration, and the `driver` and `open` elements are passed directly to database/sql to access the specified database.
You may include as many configurations as you like, and you can use the `--config` command line option to specify which one to use. goose defaults to using a configuration called `development`.
You may include as many configurations as you like, and you can use the `--config` command line option to specify which one to use. goose defaults to using a configuration called `development`.

View File

@ -27,10 +27,16 @@ type DBVersion struct {
TStamp time.Time
}
type Migration struct {
Next int // next version, or -1 if none
Previous int // previous version, -1 if none
Source string // .go or .sql script
}
type MigrationMap struct {
Versions []int // sorted slice of version keys
Sources map[int]string // sources (.sql or .go) keyed by version
Direction bool // sort direction: true -> Up, false -> Down
Versions []int // sorted slice of version keys
Migrations map[int]Migration // sources (.sql or .go) keyed by version
Direction bool // sort direction: true -> Up, false -> Down
}
var dbFolder = flag.String("db", "db", "folder containing db info")
@ -53,7 +59,7 @@ func dbConfFromFile(path, envtype string) (*DBConf, error) {
f, err := yaml.ReadFile(path)
if err != nil {
log.Fatal(err)
return nil, err
}
drv, derr := f.Get(fmt.Sprintf("%s.driver", envtype))
@ -85,12 +91,12 @@ func runMigrations(conf *DBConf, target int) {
log.Fatal("couldn't get/set DB version")
}
migrations, err := collectMigrations(path.Join(*dbFolder, "migrations"), currentVersion)
mm, err := collectMigrations(path.Join(*dbFolder, "migrations"), currentVersion)
if err != nil {
log.Fatal(err)
}
if len(migrations.Versions) == 0 {
if len(mm.Versions) == 0 {
fmt.Printf("goose: no migrations to run. current version: %d\n", currentVersion)
return
}
@ -98,18 +104,18 @@ func runMigrations(conf *DBConf, target int) {
fmt.Printf("goose: migrating db configuration '%v', current version: %d, target: %d\n",
conf.Name, currentVersion, *targetVersion)
for _, v := range migrations.Versions {
for _, v := range mm.Versions {
var numStatements int
var e error
filepath := migrations.Sources[v]
filepath := mm.Migrations[v].Source
switch path.Ext(filepath) {
case ".go":
numStatements, e = runGoMigration(conf, filepath, v, migrations.Direction)
numStatements, e = runGoMigration(conf, filepath, v, mm.Direction)
case ".sql":
numStatements, e = runSQLMigration(db, filepath, v, migrations.Direction)
numStatements, e = runSQLMigration(db, filepath, v, mm.Direction)
}
if e != nil {
@ -135,7 +141,7 @@ func collectMigrations(dirpath string, currentVersion int) (mm *MigrationMap, er
}
mm = &MigrationMap{
Sources: make(map[int]string),
Migrations: make(map[int]Migration),
}
// extract the numeric component of each migration,
@ -153,7 +159,7 @@ func collectMigrations(dirpath string, currentVersion int) (mm *MigrationMap, er
continue
}
if _, ok := mm.Sources[v]; ok {
if _, ok := mm.Migrations[v]; ok {
log.Fatalf("more than one file specifies the migration for version %d (%s and %s)",
v, mm.Versions[v], path.Join(dirpath, name))
}
@ -190,7 +196,11 @@ func versionFilter(v, current, target int) bool {
func (m *MigrationMap) Append(v int, source string) {
m.Versions = append(m.Versions, v)
m.Sources[v] = source
m.Migrations[v] = Migration{
Next: -1,
Previous: -1,
Source: source,
}
}
func (m *MigrationMap) Sort(currentVersion int) {
@ -208,6 +218,24 @@ func (m *MigrationMap) Sort(currentVersion int) {
m.Versions[i], m.Versions[j] = m.Versions[j], m.Versions[i]
}
}
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
//
// work around http://code.google.com/p/go/issues/detail?id=3117
previousV := -1
for _, v := range m.Versions {
cur := m.Migrations[v]
cur.Previous = previousV
// if a migration exists at prev, its next is now v
if prev, ok := m.Migrations[previousV]; ok {
prev.Next = v
m.Migrations[previousV] = prev
}
previousV = v
}
}
// look for migration scripts with names in the form:
@ -247,7 +275,7 @@ func ensureDBVersion(db *sql.DB) (int, error) {
);`
insert := "INSERT INTO goose_db_version (version_id) VALUES (0);"
for _, str := range [2]string{create, insert} {
for _, str := range []string{create, insert} {
if _, err := txn.Exec(str); err != nil {
txn.Rollback()
return 0, err