Merge pull request #9 from pressly/refactor

Refactor
pull/10/head
Vojtech Vitek 2016-10-03 19:41:15 -04:00 committed by GitHub
commit cb330a2e11
7 changed files with 269 additions and 275 deletions

22
down.go
View File

@ -1,27 +1,25 @@
package goose
import "database/sql"
import (
"database/sql"
"fmt"
)
func Down(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
currentVersion, err := GetDBVersion(db)
if err != nil {
return err
}
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(false) // descending, Next will be Previous
previous, err := migrations.Next(current)
migrations, err := collectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
if err = RunMigrations(db, dir, previous); err != nil {
return err
current, err := migrations.Current(currentVersion)
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return nil
return current.Down(db)
}

View File

@ -12,23 +12,7 @@ var (
maxVersion = int64((1 << 63) - 1)
)
func checkVersionDuplicates(dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
// try both directions
migrations.Sort(false)
migrations.Sort(true)
return nil
}
func Run(command string, db *sql.DB, dir string, args ...string) error {
if err := checkVersionDuplicates(dir); err != nil {
return err
}
switch command {
case "up":
if err := Up(db, dir); err != nil {

View File

@ -8,36 +8,17 @@ import (
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"text/template"
"time"
)
var (
ErrNoPreviousVersion = errors.New("no previous version found")
ErrNoNextVersion = errors.New("no next version found")
ErrNoCurrentVersion = errors.New("no current version found")
ErrNoNextVersion = errors.New("no next version found")
MaxVersion = 9223372036854775807 // max(int64)
goMigrations []*Migration
)
type MigrationRecord struct {
VersionId int64
TStamp time.Time
IsApplied bool // was this a result of up() or down()
}
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script
Up func(*sql.Tx) error // Up go migration function
Down func(*sql.Tx) error // Down go migration function
}
type Migrations []*Migration
// helpers so we can use pkg sort
@ -50,122 +31,54 @@ func (ms Migrations) Less(i, j int) bool {
return ms[i].Version < ms[j].Version
}
func (ms Migrations) Sort(up bool) {
// sort ascending or descending by version
if up {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}
func (ms Migrations) Last() (int64, error) {
if len(ms) == 0 {
return -1, ErrNoNextVersion
}
return ms[len(ms)-1].Version, nil
}
func (ms Migrations) Next(current int64) (int64, error) {
exceptLast := ms[:len(ms)-1]
for i, migration := range exceptLast {
func (ms Migrations) Current(current int64) (*Migration, error) {
for i, migration := range ms {
if migration.Version == current {
return ms[i+1].Version, nil
return ms[i], nil
}
}
return -1, ErrNoNextVersion
return nil, ErrNoCurrentVersion
}
func (ms Migrations) Next(current int64) (*Migration, error) {
for i, migration := range ms {
if migration.Version > current {
return ms[i], nil
}
}
return nil, ErrNoNextVersion
}
func (ms Migrations) Last() (*Migration, error) {
if len(ms) == 0 {
return nil, ErrNoNextVersion
}
return ms[len(ms)-1], nil
}
func (ms Migrations) String() string {
str := ""
for _, m := range ms {
str += fmt.Sprintln(m)
}
return str
}
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
v, _ := NumericComponent(filename)
migration := &Migration{Version: v, Next: -1, Previous: -1, Up: up, Down: down, Source: filename}
migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename}
goMigrations = append(goMigrations, migration)
}
func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
current, err := EnsureDBVersion(db)
if err != nil {
return err
}
if current == target {
fmt.Printf("goose: no migrations to run. current version: %d. target version: %d\n", current, target)
return nil
}
migrations, err := CollectMigrations(dir, current, target)
if err != nil {
return err
}
if len(migrations) == 0 {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
return nil
}
ms := Migrations(migrations)
direction := current < target
ms.Sort(direction)
fmt.Printf("goose: migrating db, current version: %d, target: %d\n", current, target)
for _, m := range ms {
switch filepath.Ext(m.Source) {
case ".sql":
if err = runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err))
}
case ".go":
tx, err := db.Begin()
if err != nil {
log.Fatal("db.Begin: ", err)
}
fn := m.Up
if !direction {
fn = m.Down
}
if fn != nil {
if err := fn(tx); err != nil {
tx.Rollback()
log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(m.Source), err)
return err
}
}
if err = FinalizeMigration(tx, direction, m.Version); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(m.Source), err)
}
}
fmt.Println("OK ", filepath.Base(m.Source))
}
return nil
}
// collect all the valid looking migration scripts in the
// migrations folder and go func registry, and key them by version
func CollectMigrations(dirpath string, current, target int64) (m Migrations, err error) {
func collectMigrations(dirpath string, current, target int64) (Migrations, error) {
var migrations Migrations
// extract the numeric component of each migration,
// filter out any uninteresting files,
@ -182,7 +95,7 @@ func CollectMigrations(dirpath string, current, target int64) (m Migrations, err
}
if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
m = append(m, migration)
migrations = append(migrations, migration)
}
}
@ -192,11 +105,24 @@ func CollectMigrations(dirpath string, current, target int64) (m Migrations, err
return nil, err
}
if versionFilter(v, current, target) {
m = append(m, migration)
migrations = append(migrations, migration)
}
}
return m, nil
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations, nil
}
func versionFilter(v, current, target int64) bool {
@ -212,31 +138,6 @@ func versionFilter(v, current, target int64) bool {
return false
}
// look for migration scripts with names in the form:
// XXX_descriptivename.ext
// where XXX specifies the version number
// and ext specifies the type of migration
func NumericComponent(name string) (int64, error) {
base := filepath.Base(name)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("not a recognized migration file type")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no separator found")
}
n, e := strconv.ParseInt(base[:idx], 10, 64)
if e == nil && n <= 0 {
return 0, errors.New("migration IDs must be greater than zero")
}
return n, e
}
// retrieve the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
@ -319,68 +220,3 @@ func GetDBVersion(db *sql.DB) (int64, error) {
return version, nil
}
func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
if migrationType != "go" && migrationType != "sql" {
return "", errors.New("migration type must be 'go' or 'sql'")
}
timestamp := t.Format("20060102150405")
filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType)
fpath := filepath.Join(dir, filename)
tmpl := sqlMigrationTemplate
if migrationType == "go" {
tmpl = goSqlMigrationTemplate
}
path, err = writeTemplateToFile(fpath, tmpl, timestamp)
return
}
// Update the version table for the given migration,
// and finalize the transaction.
func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := GetDialect().insertVersionSql()
if _, err := tx.Exec(stmt, v, direction); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`
-- +goose Up
-- SQL in section 'Up' is executed when this migration is applied
-- +goose Down
-- SQL section 'Down' is executed when this migration is rolled back
`))
var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
package migration
import (
"database/sql"
"github.com/pressly/goose"
)
func init() {
goose.AddMigration(Up_{{.}}, Down_{{.}})
}
func Up_{{.}}(tx *sql.Tx) error {
return nil
}
func Down_{{.}}(tx *sql.Tx) error {
return nil
}
`))

165
migration.go Normal file
View File

@ -0,0 +1,165 @@
package goose
import (
"database/sql"
"errors"
"fmt"
"log"
"path/filepath"
"strconv"
"strings"
"text/template"
"time"
)
type MigrationRecord struct {
VersionId int64
TStamp time.Time
IsApplied bool // was this a result of up() or down()
}
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script
UpFn func(*sql.Tx) error // Up go migration function
DownFn func(*sql.Tx) error // Down go migration function
}
func (m *Migration) String() string {
return fmt.Sprintf(m.Source)
}
func (m *Migration) Up(db *sql.DB) error {
return m.run(db, true)
}
func (m *Migration) Down(db *sql.DB) error {
return m.run(db, false)
}
func (m *Migration) run(db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err))
}
case ".go":
tx, err := db.Begin()
if err != nil {
log.Fatal("db.Begin: ", err)
}
fn := m.UpFn
if !direction {
fn = m.DownFn
}
if fn != nil {
if err := fn(tx); err != nil {
tx.Rollback()
log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(m.Source), err)
return err
}
}
if err = FinalizeMigration(tx, direction, m.Version); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(m.Source), err)
}
}
fmt.Println("OK ", filepath.Base(m.Source))
return nil
}
// look for migration scripts with names in the form:
// XXX_descriptivename.ext
// where XXX specifies the version number
// and ext specifies the type of migration
func NumericComponent(name string) (int64, error) {
base := filepath.Base(name)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("not a recognized migration file type")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no separator found")
}
n, e := strconv.ParseInt(base[:idx], 10, 64)
if e == nil && n <= 0 {
return 0, errors.New("migration IDs must be greater than zero")
}
return n, e
}
func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
if migrationType != "go" && migrationType != "sql" {
return "", errors.New("migration type must be 'go' or 'sql'")
}
timestamp := t.Format("20060102150405")
filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType)
fpath := filepath.Join(dir, filename)
tmpl := sqlMigrationTemplate
if migrationType == "go" {
tmpl = goSqlMigrationTemplate
}
path, err = writeTemplateToFile(fpath, tmpl, timestamp)
return
}
// Update the version table for the given migration,
// and finalize the transaction.
func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := GetDialect().insertVersionSql()
if _, err := tx.Exec(stmt, v, direction); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`
-- +goose Up
-- SQL in section 'Up' is executed when this migration is applied
-- +goose Down
-- SQL section 'Down' is executed when this migration is rolled back
`))
var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
package migration
import (
"database/sql"
"github.com/pressly/goose"
)
func init() {
goose.AddMigration(Up_{{.}}, Down_{{.}})
}
func Up_{{.}}(tx *sql.Tx) error {
return nil
}
func Down_{{.}}(tx *sql.Tx) error {
return nil
}
`))

24
redo.go
View File

@ -5,27 +5,31 @@ import (
)
func Redo(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
currentVersion, err := GetDBVersion(db)
if err != nil {
return err
}
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(false) // descending, Next will be Previous
previous, err := migrations.Next(current)
migrations, err := collectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
if err := RunMigrations(db, dir, previous); err != nil {
current, err := migrations.Current(currentVersion)
if err != nil {
return err
}
if err := RunMigrations(db, dir, current); err != nil {
previous, err := migrations.Next(currentVersion)
if err != nil {
return err
}
if err := previous.Up(db); err != nil {
return err
}
if err := current.Up(db); err != nil {
return err
}

View File

@ -10,18 +10,16 @@ import (
func Status(db *sql.DB, dir string) error {
// collect all migrations
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
migrations, err := collectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)
// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
return err
}
fmt.Println("goose: status")
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, migration := range migrations {

47
up.go
View File

@ -6,44 +6,53 @@ import (
)
func Up(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)
target, err := migrations.Last()
migrations, err := collectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
if err := RunMigrations(db, dir, target); err != nil {
return err
for {
current, err := GetDBVersion(db)
if err != nil {
return err
}
next, err := migrations.Next(current)
if err != nil {
if err == ErrNoNextVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
}
return err
}
if err = next.Up(db); err != nil {
return err
}
}
return nil
}
func UpByOne(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
migrations.Sort(true)
current, err := GetDBVersion(db)
migrations, err := collectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
next, err := migrations.Next(current)
currentVersion, err := GetDBVersion(db)
if err != nil {
return err
}
next, err := migrations.Next(currentVersion)
if err != nil {
if err == ErrNoNextVersion {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
fmt.Printf("goose: no migrations to run. current version: %d\n", currentVersion)
}
return err
}
if err = RunMigrations(db, dir, next); err != nil {
if err = next.Up(db); err != nil {
return err
}