goose/migrate.go
Nicholas Duffy 0d5a6db9db Update CollectMigrations for subdirectories
I like to group my migrations by logical subdirectories.

```
[nicholasduffy@duffn:~/go/src/github.com/pressly/goose/migrations on sql-subdirectories]
% tree                                                                                                                                              ✭
.
├── group1
│   └── 20170506073854_table1_create_table.sql
└── group2
    └── 20170506073920_table2_create_table.sql

2 directories, 2 files
```

It's easy to run migrations by groups if necessary. However, I'd like the
ability to run all migrations in all subdirectories at the same time. This
PR adds discovery of all SQL files in subdirectories within the `-dir`
directory.
2017-05-06 08:04:19 -06:00

242 lines
5.3 KiB
Go

package goose
import (
"database/sql"
"errors"
"fmt"
"log"
"path/filepath"
"runtime"
"sort"
)
var (
// ErrNoCurrentVersion is returned by Migrations.Current when no collected
// migration has the requested version.
ErrNoCurrentVersion = errors.New("no current version found")
// ErrNoNextVersion is returned by Migrations.Next, Migrations.Previous and
// Migrations.Last when there is no migration to return, and by
// EnsureDBVersion when no applied version row is found.
ErrNoNextVersion = errors.New("no next version found")
// MaxVersion is the largest value an int64 can hold.
// NOTE(review): presumably used as the "migrate to latest" target; confirm
// against callers.
MaxVersion int64 = 9223372036854775807 // max(int64)
// goMigrations is the registry of Go-function migrations; it is appended to
// by AddNamedMigration and read by CollectMigrations.
goMigrations []*Migration
)
// Migrations is a list of migrations. Together with Less it implements
// sort.Interface so the list can be ordered with package sort.
type Migrations []*Migration

// Len reports the number of migrations in the list.
func (ms Migrations) Len() int {
	return len(ms)
}

// Swap exchanges the migrations at positions i and j.
func (ms Migrations) Swap(i, j int) {
	ms[j], ms[i] = ms[i], ms[j]
}
// Less reports whether the migration at index i has a lower version than the
// one at index j.
//
// Two migrations sharing a version is treated as fatal: the process is
// terminated through log.Fatalf with both source paths in the message.
func (ms Migrations) Less(i, j int) bool {
	left, right := ms[i], ms[j]
	if left.Version == right.Version {
		log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", left.Version, left.Source, right.Source)
	}
	return left.Version < right.Version
}
// Current returns the migration whose version equals current, or
// ErrNoCurrentVersion when the list holds no such migration.
func (ms Migrations) Current(current int64) (*Migration, error) {
	for _, m := range ms {
		if m.Version != current {
			continue
		}
		return m, nil
	}
	return nil, ErrNoCurrentVersion
}
// Next returns the first migration, in list order, whose version is greater
// than current. ErrNoNextVersion is returned when there is none.
func (ms Migrations) Next(current int64) (*Migration, error) {
	for idx := 0; idx < len(ms); idx++ {
		if candidate := ms[idx]; candidate.Version > current {
			return candidate, nil
		}
	}
	return nil, ErrNoNextVersion
}
// Previous scans the list from its end towards its start and returns the
// first migration whose version is lower than current. ErrNoNextVersion is
// returned when there is none.
func (ms Migrations) Previous(current int64) (*Migration, error) {
	for offset := range ms {
		candidate := ms[len(ms)-1-offset]
		if candidate.Version < current {
			return candidate, nil
		}
	}
	return nil, ErrNoNextVersion
}
// Last returns the final migration in the list, or ErrNoNextVersion when the
// list is empty.
func (ms Migrations) Last() (*Migration, error) {
	if count := len(ms); count > 0 {
		return ms[count-1], nil
	}
	return nil, ErrNoNextVersion
}
// String renders every migration with fmt.Sprintln, one per line, and
// returns the concatenation. An empty list yields the empty string.
func (ms Migrations) String() string {
	var rendered string
	for idx := range ms {
		rendered = rendered + fmt.Sprintln(ms[idx])
	}
	return rendered
}
// AddMigration registers a Go migration consisting of the given up and down
// functions. The migration is named after the source file of the function
// that called AddMigration, as reported by runtime.Caller.
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
	// Skip one frame so we get the caller's file, not this one.
	_, callerFile, _, _ := runtime.Caller(1)
	AddNamedMigration(callerFile, up, down)
}
// AddNamedMigration registers a Go migration under the given filename by
// appending it to the package-level goMigrations registry.
//
// The version is taken from NumericComponent(filename); the error from that
// call is discarded. Next and Previous start out as -1.
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
	version, _ := NumericComponent(filename)
	goMigrations = append(goMigrations, &Migration{
		Version:  version,
		Next:     -1,
		Previous: -1,
		UpFn:     up,
		DownFn:   down,
		Source:   filename,
	})
}
// CollectMigrations returns the migrations selected by versionFilter for the
// given current and target versions, drawn from both the SQL scripts under
// dirpath and the Go migrations registered through AddMigration or
// AddNamedMigration.
//
// SQL scripts are discovered directly inside dirpath and inside its immediate
// subdirectories. filepath.Glob has no recursive wildcard ("**" matches
// exactly what "*" matches), so a single "dirpath/**/*.sql" pattern would
// only see files one directory down and would miss scripts placed in dirpath
// itself. The two depths are therefore globbed separately. Directories nested
// more than one level deep are not searched.
//
// The result is sorted by version with Next and Previous populated on each
// migration. An error is returned when a glob pattern is malformed or when
// NumericComponent cannot extract a version from a source name.
func CollectMigrations(dirpath string, current, target int64) (Migrations, error) {
	var migrations Migrations

	patterns := []string{
		filepath.Join(dirpath, "*.sql"),      // scripts directly in dirpath
		filepath.Join(dirpath, "*", "*.sql"), // scripts grouped in subdirectories
	}

	// Extract the numeric component of each SQL migration and keep only the
	// versions that lie between current and target.
	for _, pattern := range patterns {
		sqlMigrations, err := filepath.Glob(pattern)
		if err != nil {
			return nil, err
		}

		for _, file := range sqlMigrations {
			v, err := NumericComponent(file)
			if err != nil {
				return nil, err
			}
			if versionFilter(v, current, target) {
				migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
				migrations = append(migrations, migration)
			}
		}
	}

	// Apply the same version filter to the registered Go migrations.
	for _, migration := range goMigrations {
		v, err := NumericComponent(migration.Source)
		if err != nil {
			return nil, err
		}
		if versionFilter(v, current, target) {
			migrations = append(migrations, migration)
		}
	}

	migrations = sortAndConnectMigrations(migrations)

	return migrations, nil
}
// sortAndConnectMigrations sorts migrations in place with sort.Sort and then
// links neighbours: each migration's Previous is set to the version before it
// (-1 for the first), and each migration that has a successor gets its Next
// set to that successor's version. The same slice is returned.
func sortAndConnectMigrations(migrations Migrations) Migrations {
	sort.Sort(migrations)

	for idx := range migrations {
		if idx == 0 {
			// Nothing precedes the first migration.
			migrations[idx].Previous = -1
			continue
		}
		older, newer := migrations[idx-1], migrations[idx]
		older.Next = newer.Version
		newer.Previous = older.Version
	}

	return migrations
}
// versionFilter reports whether version v should be included when moving
// from current to target.
//
// When target is above current the accepted range is (current, target];
// when target is below current it is (target, current]. When the two are
// equal nothing is accepted.
func versionFilter(v, current, target int64) bool {
	switch {
	case target > current:
		return current < v && v <= target
	case target < current:
		return target < v && v <= current
	default:
		return false
	}
}
// EnsureDBVersion retrieves the current version for this DB.
// It creates and initializes the DB version table when the version query
// fails.
//
// It returns the version of the first row reported as applied, skipping any
// version whose first-seen row says it is not applied. If the query fails,
// the result is 0 together with whatever createVersionTable returns. If no
// applied row is found, ErrNoNextVersion is returned. Failures while scanning
// or iterating rows are returned to the caller instead of terminating the
// process.
func EnsureDBVersion(db *sql.DB) (int64, error) {
	rows, err := GetDialect().dbVersionQuery(db)
	if err != nil {
		// Any query failure is taken to mean the version table is missing.
		return 0, createVersionTable(db)
	}
	defer rows.Close()

	// The most recent record for each migration specifies
	// whether it has been applied or rolled back.
	// The first version we find that has been applied is the current version.
	skipped := make(map[int64]struct{})

	for rows.Next() {
		var row MigrationRecord
		if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
			return 0, fmt.Errorf("error scanning rows: %v", err)
		}

		// have we already marked this version to be skipped?
		if _, seen := skipped[row.VersionId]; seen {
			continue
		}

		// if version has been applied we're done
		if row.IsApplied {
			return row.VersionId, nil
		}

		// latest version of migration has not been applied.
		skipped[row.VersionId] = struct{}{}
	}

	// rows.Next returns false on both exhaustion and failure; tell them apart
	// so a broken result set is not reported as "no version".
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("error iterating rows: %v", err)
	}

	return 0, ErrNoNextVersion
}
// createVersionTable creates the version table and inserts the initial
// version 0, marked as applied, inside a single transaction.
//
// The transaction is rolled back and the error returned if either statement
// fails; otherwise the result of the commit is returned.
func createVersionTable(db *sql.DB) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	dialect := GetDialect()

	_, err = tx.Exec(dialect.createVersionTableSql())
	if err != nil {
		tx.Rollback()
		return err
	}

	// Seed the table with version 0 recorded as applied.
	initialVersion, isApplied := 0, true
	_, err = tx.Exec(dialect.insertVersionSql(), initialVersion, isApplied)
	if err != nil {
		tx.Rollback()
		return err
	}

	return tx.Commit()
}
// GetDBVersion is a thin wrapper around EnsureDBVersion. It returns the
// version EnsureDBVersion reports, or -1 together with the error when
// EnsureDBVersion fails.
func GetDBVersion(db *sql.DB) (int64, error) {
	version, err := EnsureDBVersion(db)
	if err == nil {
		return version, nil
	}
	return -1, err
}