mirror of
https://github.com/pressly/goose.git
synced 2025-05-20 22:41:25 +00:00
I like to group my migrations by logical subdirectories:

```
[nicholasduffy@duffn:~/go/src/github.com/pressly/goose/migrations on sql-subdirectories] % tree
.
├── group1
│   └── 20170506073854_table1_create_table.sql
└── group2
    └── 20170506073920_table2_create_table.sql

2 directories, 2 files
```

It's easy to run migrations by group if necessary. However, I'd like the ability to run all migrations in all subdirectories at the same time. This PR adds discovery of all SQL files in subdirectories of the `-dir` directory.
242 lines
5.3 KiB
Go
242 lines
5.3 KiB
Go
package goose
|
|
|
|
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"runtime"
	"sort"
)
|
|
|
|
var (
|
|
ErrNoCurrentVersion = errors.New("no current version found")
|
|
ErrNoNextVersion = errors.New("no next version found")
|
|
|
|
MaxVersion int64 = 9223372036854775807 // max(int64)
|
|
|
|
goMigrations []*Migration
|
|
)
|
|
|
|
type Migrations []*Migration
|
|
|
|
// helpers so we can use pkg sort
|
|
func (ms Migrations) Len() int { return len(ms) }
|
|
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
|
|
func (ms Migrations) Less(i, j int) bool {
|
|
if ms[i].Version == ms[j].Version {
|
|
log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source)
|
|
}
|
|
return ms[i].Version < ms[j].Version
|
|
}
|
|
|
|
func (ms Migrations) Current(current int64) (*Migration, error) {
|
|
for i, migration := range ms {
|
|
if migration.Version == current {
|
|
return ms[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrNoCurrentVersion
|
|
}
|
|
|
|
func (ms Migrations) Next(current int64) (*Migration, error) {
|
|
for i, migration := range ms {
|
|
if migration.Version > current {
|
|
return ms[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrNoNextVersion
|
|
}
|
|
|
|
func (ms Migrations) Previous(current int64) (*Migration, error) {
|
|
for i := len(ms)-1; i >= 0; i-- {
|
|
if ms[i].Version < current {
|
|
return ms[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrNoNextVersion
|
|
}
|
|
|
|
func (ms Migrations) Last() (*Migration, error) {
|
|
if len(ms) == 0 {
|
|
return nil, ErrNoNextVersion
|
|
}
|
|
|
|
return ms[len(ms)-1], nil
|
|
}
|
|
|
|
func (ms Migrations) String() string {
|
|
str := ""
|
|
for _, m := range ms {
|
|
str += fmt.Sprintln(m)
|
|
}
|
|
return str
|
|
}
|
|
|
|
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
|
_, filename, _, _ := runtime.Caller(1)
|
|
AddNamedMigration(filename, up, down)
|
|
}
|
|
|
|
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
|
v, _ := NumericComponent(filename)
|
|
migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename}
|
|
|
|
goMigrations = append(goMigrations, migration)
|
|
}
|
|
|
|
// CollectMigrations returns all the valid looking migration scripts in the
|
|
// migrations folder and go func registry, and key them by version.
|
|
func CollectMigrations(dirpath string, current, target int64) (Migrations, error) {
|
|
var migrations Migrations
|
|
|
|
// extract the numeric component of each migration,
|
|
// filter out any uninteresting files,
|
|
// and ensure we only have one file per migration version.
|
|
sqlMigrations, err := filepath.Glob(dirpath + "/**/*.sql")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, file := range sqlMigrations {
|
|
v, err := NumericComponent(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if versionFilter(v, current, target) {
|
|
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
|
|
migrations = append(migrations, migration)
|
|
}
|
|
}
|
|
|
|
for _, migration := range goMigrations {
|
|
v, err := NumericComponent(migration.Source)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if versionFilter(v, current, target) {
|
|
migrations = append(migrations, migration)
|
|
}
|
|
}
|
|
|
|
migrations = sortAndConnectMigrations(migrations)
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
func sortAndConnectMigrations(migrations Migrations) Migrations {
|
|
sort.Sort(migrations)
|
|
|
|
// now that we're sorted in the appropriate direction,
|
|
// populate next and previous for each migration
|
|
for i, m := range migrations {
|
|
prev := int64(-1)
|
|
if i > 0 {
|
|
prev = migrations[i-1].Version
|
|
migrations[i-1].Next = m.Version
|
|
}
|
|
migrations[i].Previous = prev
|
|
}
|
|
|
|
return migrations
|
|
}
|
|
|
|
// versionFilter reports whether version v must be run to move from current
// to target.
//
// Migrating up (target > current) selects the half-open range (current, target].
// Migrating down (target < current) selects (target, current].
// When target == current nothing is selected.
func versionFilter(v, current, target int64) bool {
	switch {
	case target > current:
		return current < v && v <= target
	case target < current:
		return target < v && v <= current
	default:
		return false
	}
}
|
|
|
|
// retrieve the current version for this DB.
|
|
// Create and initialize the DB version table if it doesn't exist.
|
|
func EnsureDBVersion(db *sql.DB) (int64, error) {
|
|
rows, err := GetDialect().dbVersionQuery(db)
|
|
if err != nil {
|
|
return 0, createVersionTable(db)
|
|
}
|
|
defer rows.Close()
|
|
|
|
// The most recent record for each migration specifies
|
|
// whether it has been applied or rolled back.
|
|
// The first version we find that has been applied is the current version.
|
|
|
|
toSkip := make([]int64, 0)
|
|
|
|
for rows.Next() {
|
|
var row MigrationRecord
|
|
if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
|
|
log.Fatal("error scanning rows:", err)
|
|
}
|
|
|
|
// have we already marked this version to be skipped?
|
|
skip := false
|
|
for _, v := range toSkip {
|
|
if v == row.VersionId {
|
|
skip = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if skip {
|
|
continue
|
|
}
|
|
|
|
// if version has been applied we're done
|
|
if row.IsApplied {
|
|
return row.VersionId, nil
|
|
}
|
|
|
|
// latest version of migration has not been applied.
|
|
toSkip = append(toSkip, row.VersionId)
|
|
}
|
|
|
|
return 0, ErrNoNextVersion
|
|
}
|
|
|
|
// Create the goose_db_version table
|
|
// and insert the initial 0 value into it
|
|
func createVersionTable(db *sql.DB) error {
|
|
txn, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
d := GetDialect()
|
|
|
|
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
|
|
txn.Rollback()
|
|
return err
|
|
}
|
|
|
|
version := 0
|
|
applied := true
|
|
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
|
|
txn.Rollback()
|
|
return err
|
|
}
|
|
|
|
return txn.Commit()
|
|
}
|
|
|
|
// wrapper for EnsureDBVersion for callers that don't already have
|
|
// their own DB instance
|
|
func GetDBVersion(db *sql.DB) (int64, error) {
|
|
version, err := EnsureDBVersion(db)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
return version, nil
|
|
}
|