diff --git a/down.go b/down.go index 6315623..e5e819a 100644 --- a/down.go +++ b/down.go @@ -1,27 +1,25 @@ package goose -import "database/sql" +import ( + "database/sql" + "fmt" +) func Down(db *sql.DB, dir string) error { - current, err := GetDBVersion(db) + currentVersion, err := GetDBVersion(db) if err != nil { return err } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - migrations.Sort(false) // descending, Next will be Previous - - previous, err := migrations.Next(current) + migrations, err := collectMigrations(dir, minVersion, maxVersion) if err != nil { return err } - if err = RunMigrations(db, dir, previous); err != nil { - return err + current, err := migrations.Current(currentVersion) + if err != nil { + return fmt.Errorf("no migration %v", currentVersion) } - return nil + return current.Down(db) } diff --git a/goose.go b/goose.go index 811d416..604e29f 100644 --- a/goose.go +++ b/goose.go @@ -12,23 +12,7 @@ var ( maxVersion = int64((1 << 63) - 1) ) -func checkVersionDuplicates(dir string) error { - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - - // try both directions - migrations.Sort(false) - migrations.Sort(true) - return nil -} - func Run(command string, db *sql.DB, dir string, args ...string) error { - if err := checkVersionDuplicates(dir); err != nil { - return err - } - switch command { case "up": if err := Up(db, dir); err != nil { diff --git a/migrate.go b/migrate.go index a1fcc73..201e110 100644 --- a/migrate.go +++ b/migrate.go @@ -8,36 +8,17 @@ import ( "path/filepath" "runtime" "sort" - "strconv" - "strings" - "text/template" - "time" ) var ( - ErrNoPreviousVersion = errors.New("no previous version found") - ErrNoNextVersion = errors.New("no next version found") + ErrNoCurrentVersion = errors.New("no current version found") + ErrNoNextVersion = errors.New("no next version found") MaxVersion = 9223372036854775807 // max(int64) goMigrations []*Migration ) -type MigrationRecord struct { - VersionId int64 - TStamp time.Time - IsApplied bool // was this a result of up() or down() -} - -type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script - Up func(*sql.Tx) error // Up go migration function - Down func(*sql.Tx) error // Down go migration function -} - type Migrations []*Migration // helpers so we can use pkg sort @@ -50,122 +31,54 @@ func (ms Migrations) Less(i, j int) bool { return ms[i].Version < ms[j].Version } -func (ms Migrations) Sort(up bool) { - - // sort ascending or descending by version - if up { - sort.Sort(ms) - } else { - sort.Sort(sort.Reverse(ms)) - } - - // now that we're sorted in the appropriate direction, - // populate next and previous for each migration - for i, m := range ms { - prev := int64(-1) - if i > 0 { - prev = ms[i-1].Version - ms[i-1].Next = m.Version - } - ms[i].Previous = prev - } -} - -func (ms Migrations) Last() (int64, error) { - if len(ms) == 0 { - return -1, ErrNoNextVersion - } - - return ms[len(ms)-1].Version, nil -} - -func (ms Migrations) Next(current int64) (int64, error) { - exceptLast := ms[:len(ms)-1] - - for i, migration := range exceptLast { +func (ms Migrations) Current(current int64) (*Migration, error) { + for i, migration := range ms { if migration.Version == current { - return ms[i+1].Version, nil + return ms[i], nil } } - return -1, ErrNoNextVersion + return nil, ErrNoCurrentVersion +} + +func (ms Migrations) Next(current int64) (*Migration, error) { + for i, migration := range ms { + if migration.Version > current { + return ms[i], nil + } + } + + return nil, ErrNoNextVersion +} + +func (ms Migrations) Last() (*Migration, error) { + if len(ms) == 0 { + return nil, ErrNoNextVersion + } + + return ms[len(ms)-1], nil +} + +func (ms Migrations) String() string { + str := "" + for _, m := range ms { + str += fmt.Sprintln(m) + } + return str } func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { _, filename, _, _ := runtime.Caller(1) v, _ := NumericComponent(filename) - migration := &Migration{Version: v, Next: -1, Previous: -1, Up: up, Down: down, Source: filename} + migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename} goMigrations = append(goMigrations, migration) } -func RunMigrations(db *sql.DB, dir string, target int64) (err error) { - current, err := EnsureDBVersion(db) - if err != nil { - return err - } - - if current == target { - fmt.Printf("goose: no migrations to run. current version: %d. target version: %d\n", current, target) - return nil - } - - migrations, err := CollectMigrations(dir, current, target) - if err != nil { - return err - } - - if len(migrations) == 0 { - fmt.Printf("goose: no migrations to run. current version: %d\n", current) - return nil - } - - ms := Migrations(migrations) - direction := current < target - ms.Sort(direction) - - fmt.Printf("goose: migrating db, current version: %d, target: %d\n", current, target) - - for _, m := range ms { - - switch filepath.Ext(m.Source) { - case ".sql": - if err = runSQLMigration(db, m.Source, m.Version, direction); err != nil { - return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) - } - - case ".go": - tx, err := db.Begin() - if err != nil { - log.Fatal("db.Begin: ", err) - } - - fn := m.Up - if !direction { - fn = m.Down - } - if fn != nil { - if err := fn(tx); err != nil { - tx.Rollback() - log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(m.Source), err) - return err - } - } - - if err = FinalizeMigration(tx, direction, m.Version); err != nil { - log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(m.Source), err) - } - } - - fmt.Println("OK ", filepath.Base(m.Source)) - } - - return nil -} - // collect all the valid looking migration scripts in the // migrations folder and go func registry, and key them by version -func CollectMigrations(dirpath string, current, target int64) (m Migrations, err error) { +func collectMigrations(dirpath string, current, target int64) (Migrations, error) { + var migrations Migrations // extract the numeric component of each migration, // filter out any uninteresting files, @@ -182,7 +95,7 @@ func CollectMigrations(dirpath string, current, target int64) (m Migrations, err } if versionFilter(v, current, target) { migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file} - m = append(m, migration) + migrations = append(migrations, migration) } } @@ -192,11 +105,24 @@ func CollectMigrations(dirpath string, current, target int64) (m Migrations, err return nil, err } if versionFilter(v, current, target) { - m = append(m, migration) + migrations = append(migrations, migration) } } - return m, nil + sort.Sort(migrations) + + // now that we're sorted in the appropriate direction, + // populate next and previous for each migration + for i, m := range migrations { + prev := int64(-1) + if i > 0 { + prev = migrations[i-1].Version + migrations[i-1].Next = m.Version + } + migrations[i].Previous = prev + } + + return migrations, nil } func versionFilter(v, current, target int64) bool { @@ -212,31 +138,6 @@ func versionFilter(v, current, target int64) bool { return false } -// look for migration scripts with names in the form: -// XXX_descriptivename.ext -// where XXX specifies the version number -// and ext specifies the type of migration -func NumericComponent(name string) (int64, error) { - - base := filepath.Base(name) - - if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { - return 0, errors.New("not a recognized migration file type") - } - - idx := strings.Index(base, "_") - if idx < 0 { - return 0, errors.New("no separator found") - } - - n, e := strconv.ParseInt(base[:idx], 10, 64) - if e == nil && n <= 0 { - return 0, errors.New("migration IDs must be greater than zero") - } - - return n, e -} - // retrieve the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { @@ -319,68 +220,3 @@ func GetDBVersion(db *sql.DB) (int64, error) { return version, nil } - -func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { - - if migrationType != "go" && migrationType != "sql" { - return "", errors.New("migration type must be 'go' or 'sql'") - } - - timestamp := t.Format("20060102150405") - filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType) - - fpath := filepath.Join(dir, filename) - tmpl := sqlMigrationTemplate - if migrationType == "go" { - tmpl = goSqlMigrationTemplate - } - - path, err = writeTemplateToFile(fpath, tmpl, timestamp) - - return -} - -// Update the version table for the given migration, -// and finalize the transaction. -func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error { - - // XXX: drop goose_db_version table on some minimum version number? - stmt := GetDialect().insertVersionSql() - if _, err := tx.Exec(stmt, v, direction); err != nil { - tx.Rollback() - return err - } - - return tx.Commit() -} - -var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(` --- +goose Up --- SQL in section 'Up' is executed when this migration is applied - - --- +goose Down --- SQL section 'Down' is executed when this migration is rolled back - -`)) -var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` -package migration - -import ( - "database/sql" - - "github.com/pressly/goose" -) - -func init() { - goose.AddMigration(Up_{{.}}, Down_{{.}}) -} - -func Up_{{.}}(tx *sql.Tx) error { - return nil -} - -func Down_{{.}}(tx *sql.Tx) error { - return nil -} -`)) diff --git a/migration.go b/migration.go new file mode 100644 index 0000000..e014c57 --- /dev/null +++ b/migration.go @@ -0,0 +1,165 @@ +package goose + +import ( + "database/sql" + "errors" + "fmt" + "log" + "path/filepath" + "strconv" + "strings" + "text/template" + "time" +) + +type MigrationRecord struct { + VersionId int64 + TStamp time.Time + IsApplied bool // was this a result of up() or down() +} + +type Migration struct { + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script + UpFn func(*sql.Tx) error // Up go migration function + DownFn func(*sql.Tx) error // Down go migration function +} + +func (m *Migration) String() string { + return fmt.Sprintf(m.Source) +} + +func (m *Migration) Up(db *sql.DB) error { + return m.run(db, true) +} + +func (m *Migration) Down(db *sql.DB) error { + return m.run(db, false) +} + +func (m *Migration) run(db *sql.DB, direction bool) error { + switch filepath.Ext(m.Source) { + case ".sql": + if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil { + return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) + } + + case ".go": + tx, err := db.Begin() + if err != nil { + log.Fatal("db.Begin: ", err) + } + + fn := m.UpFn + if !direction { + fn = m.DownFn + } + if fn != nil { + if err := fn(tx); err != nil { + tx.Rollback() + log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(m.Source), err) + return err + } + } + + if err = FinalizeMigration(tx, direction, m.Version); err != nil { + log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(m.Source), err) + } + } + + fmt.Println("OK ", filepath.Base(m.Source)) + + return nil +} + +// look for migration scripts with names in the form: +// XXX_descriptivename.ext +// where XXX specifies the version number +// and ext specifies the type of migration +func NumericComponent(name string) (int64, error) { + + base := filepath.Base(name) + + if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { + return 0, errors.New("not a recognized migration file type") + } + + idx := strings.Index(base, "_") + if idx < 0 { + return 0, errors.New("no separator found") + } + + n, e := strconv.ParseInt(base[:idx], 10, 64) + if e == nil && n <= 0 { + return 0, errors.New("migration IDs must be greater than zero") + } + + return n, e +} + +func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { + + if migrationType != "go" && migrationType != "sql" { + return "", errors.New("migration type must be 'go' or 'sql'") + } + + timestamp := t.Format("20060102150405") + filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType) + + fpath := filepath.Join(dir, filename) + tmpl := sqlMigrationTemplate + if migrationType == "go" { + tmpl = goSqlMigrationTemplate + } + + path, err = writeTemplateToFile(fpath, tmpl, timestamp) + + return +} + +// Update the version table for the given migration, +// and finalize the transaction. +func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error { + + // XXX: drop goose_db_version table on some minimum version number? + stmt := GetDialect().insertVersionSql() + if _, err := tx.Exec(stmt, v, direction); err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + +var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(` +-- +goose Up +-- SQL in section 'Up' is executed when this migration is applied + + +-- +goose Down +-- SQL section 'Down' is executed when this migration is rolled back + +`)) +var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` +package migration + +import ( + "database/sql" + + "github.com/pressly/goose" +) + +func init() { + goose.AddMigration(Up_{{.}}, Down_{{.}}) +} + +func Up_{{.}}(tx *sql.Tx) error { + return nil +} + +func Down_{{.}}(tx *sql.Tx) error { + return nil +} +`)) diff --git a/redo.go b/redo.go index 6a3a204..b78b081 100644 --- a/redo.go +++ b/redo.go @@ -5,27 +5,31 @@ import ( ) func Redo(db *sql.DB, dir string) error { - current, err := GetDBVersion(db) + currentVersion, err := GetDBVersion(db) if err != nil { return err } - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - migrations.Sort(false) // descending, Next will be Previous - - previous, err := migrations.Next(current) + migrations, err := collectMigrations(dir, minVersion, maxVersion) if err != nil { return err } - if err := RunMigrations(db, dir, previous); err != nil { + current, err := migrations.Current(currentVersion) + if err != nil { return err } - if err := RunMigrations(db, dir, current); err != nil { + previous, err := migrations.Next(currentVersion) + if err != nil { + return err + } + + if err := previous.Up(db); err != nil { + return err + } + + if err := current.Up(db); err != nil { return err } diff --git a/status.go b/status.go index 9a8fbbc..6819712 100644 --- a/status.go +++ b/status.go @@ -10,18 +10,16 @@ import ( func Status(db *sql.DB, dir string) error { // collect all migrations - migrations, err := CollectMigrations(dir, minVersion, maxVersion) + migrations, err := collectMigrations(dir, minVersion, maxVersion) if err != nil { return err } - migrations.Sort(true) // must ensure that the version table exists if we're running on a pristine DB if _, err := EnsureDBVersion(db); err != nil { return err } - fmt.Println("goose: status") fmt.Println(" Applied At Migration") fmt.Println(" =======================================") for _, migration := range migrations { diff --git a/up.go b/up.go index 9ff520d..0bec004 100644 --- a/up.go +++ b/up.go @@ -6,44 +6,53 @@ import ( ) func Up(db *sql.DB, dir string) error { - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - migrations.Sort(true) - - target, err := migrations.Last() + migrations, err := collectMigrations(dir, minVersion, maxVersion) if err != nil { return err } - if err := RunMigrations(db, dir, target); err != nil { - return err + for { + current, err := GetDBVersion(db) + if err != nil { + return err + } + + next, err := migrations.Next(current) + if err != nil { + if err == ErrNoNextVersion { + fmt.Printf("goose: no migrations to run. current version: %d\n", current) + } + return err + } + + if err = next.Up(db); err != nil { + return err + } } + return nil } func UpByOne(db *sql.DB, dir string) error { - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - migrations.Sort(true) - - current, err := GetDBVersion(db) + migrations, err := collectMigrations(dir, minVersion, maxVersion) if err != nil { return err } - next, err := migrations.Next(current) + currentVersion, err := GetDBVersion(db) + if err != nil { + return err + } + + next, err := migrations.Next(currentVersion) if err != nil { if err == ErrNoNextVersion { - fmt.Printf("goose: no migrations to run. current version: %d\n", current) + fmt.Printf("goose: no migrations to run. current version: %d\n", currentVersion) } return err } - if err = RunMigrations(db, dir, next); err != nil { + if err = next.Up(db); err != nil { return err }