From 45eeb19d7d952531c3102add91cad61f1d02a410 Mon Sep 17 00:00:00 2001 From: 1vn Date: Tue, 30 Oct 2018 16:45:45 -0400 Subject: [PATCH] add fix and timestamp default --- create.go | 14 ++------- dialect.go | 20 +++++++++++++ fix.go | 45 +++++++++++++++++++++++++++++ fix_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++ goose.go | 5 ++++ migrate.go | 38 +++++++++++++++++++++++++ 6 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 fix.go create mode 100644 fix_test.go diff --git a/create.go b/create.go index b863076..6761dec 100644 --- a/create.go +++ b/create.go @@ -6,22 +6,12 @@ import ( "os" "path/filepath" "text/template" + "time" ) // Create writes a new blank migration file. func CreateWithTemplate(db *sql.DB, dir string, migrationTemplate *template.Template, name, migrationType string) error { - migrations, err := CollectMigrations(dir, minVersion, maxVersion) - if err != nil { - return err - } - - // Initial version. - version := "00001" - - if last, err := migrations.Last(); err == nil { - version = fmt.Sprintf("%05v", last.Version+1) - } - + version := time.Now().Format(timestampFormat) filename := fmt.Sprintf("%v_%v.%v", version, name, migrationType) fpath := filepath.Join(dir, filename) diff --git a/dialect.go b/dialect.go index 488f5e0..edc2c35 100644 --- a/dialect.go +++ b/dialect.go @@ -61,6 +61,10 @@ func (pg PostgresDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) } +func (pg PostgresDialect) updateVersionSQL() string { + return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName()) +} + func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { @@ -91,6 +95,10 @@ func (m MySQLDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } +func (m MySQLDialect) updateVersionSQL() string { + return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName()) +} + func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { @@ -120,6 +128,10 @@ func (m Sqlite3Dialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } +func (m Sqlite3Dialect) updateVersionSQL() string { + return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName()) +} + func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { @@ -150,6 +162,10 @@ func (rs RedshiftDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) } +func (rs RedshiftDialect) updateVersionSQL() string { + return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName()) +} + func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { @@ -180,6 +196,10 @@ func (m TiDBDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } +func (m TiDBDialect) updateVersionSQL() string { + return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName()) +} + func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { diff --git a/fix.go b/fix.go new file mode 100644 index 0000000..abb9d19 --- /dev/null +++ b/fix.go @@ -0,0 +1,45 @@ +package goose + +import ( + "database/sql" + "fmt" + "os" + "strings" +) + +func Fix(db *sql.DB, dir string) error { + migrations, err := CollectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + // split into timestamped and versioned migrations + tsMigrations, err := migrations.Timestamped() + if err != nil { + return err + } + + vMigrations, err := migrations.Versioned() + if err != nil { + return err + } + // Initial version. + version := int64(1) + if last, err := vMigrations.Last(); err == nil { + version = last.Version + 1 + } + + // fix filenames by replacing timestamps with sequential versions + for _, tsm := range tsMigrations { + oldPath := tsm.Source + newPath := strings.Replace(oldPath, fmt.Sprintf("%d", tsm.Version), fmt.Sprintf("%05v", version), 1) + + if err := os.Rename(oldPath, newPath); err != nil { + return err + } + + version++ + } + + return nil +} diff --git a/fix_test.go b/fix_test.go new file mode 100644 index 0000000..e2f9fc9 --- /dev/null +++ b/fix_test.go @@ -0,0 +1,81 @@ +package goose + +import ( + "fmt" + "io/ioutil" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +func TestFix(t *testing.T) { + dir, err := ioutil.TempDir("", "tmptest") + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll(dir) // clean up + defer os.Remove("goose") // clean up + + commands := []string{ + "go build -i -o goose ./cmd/goose", + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s sqlite3 sql.db fix", dir), + } + + for _, cmd := range commands { + args := strings.Split(cmd, " ") + time.Sleep(1 * time.Second) + out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + if err != nil { + t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) + } + } + + files, err := ioutil.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + // check that the files are in order + for i, f := range files { + expected := fmt.Sprintf("%05v", i+1) + if !strings.HasPrefix(f.Name(), expected) { + t.Errorf("failed to find %s prefix in %s", expected, f.Name()) + } + } + + // add more migrations and then fix it + commands = []string{ + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s create yolo", dir), + fmt.Sprintf("./goose -dir=%s sqlite3 sql.db fix", dir), + } + + for _, cmd := range commands { + args := strings.Split(cmd, " ") + time.Sleep(1 * time.Second) + out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + if err != nil { + t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) + } + } + + files, err = ioutil.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + // check that the files still in order + for i, f := range files { + expected := fmt.Sprintf("%05v", i+1) + if !strings.HasPrefix(f.Name(), expected) { + t.Errorf("failed to find %s prefix in %s", expected, f.Name()) + } + } +} diff --git a/goose.go b/goose.go index 51ca6ec..1c492b0 100644 --- a/goose.go +++ b/goose.go @@ -11,6 +11,7 @@ var ( duplicateCheckOnce sync.Once minVersion = int64(0) maxVersion = int64((1 << 63) - 1) + timestampFormat = "20060102150405" ) // Run runs a goose command. @@ -64,6 +65,10 @@ func Run(command string, db *sql.DB, dir string, args ...string) error { if err := DownTo(db, dir, version); err != nil { return err } + case "fix": + if err := Fix(db, dir); err != nil { + return err + } case "redo": if err := Redo(db, dir); err != nil { return err diff --git a/migrate.go b/migrate.go index 4774af0..f998b27 100644 --- a/migrate.go +++ b/migrate.go @@ -8,6 +8,7 @@ import ( "path/filepath" "runtime" "sort" + "time" ) var ( @@ -76,6 +77,43 @@ func (ms Migrations) Last() (*Migration, error) { return ms[len(ms)-1], nil } +// Versioned gets versioned migrations. +func (ms Migrations) Versioned() (Migrations, error) { + var migrations Migrations + + // assume that the user will never have more than 19700101000000 migrations + for _, m := range ms { + // parse version as timestmap + versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version)) + + if versionTime.Before(time.Unix(0, 0)) || err != nil { + migrations = append(migrations, m) + } + } + + return migrations, nil +} + +// Timestamped gets the timestamped migrations. +func (ms Migrations) Timestamped() (Migrations, error) { + var migrations Migrations + + // assume that the user will never have more than 19700101000000 migrations + for _, m := range ms { + // parse version as timestmap + versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version)) + if err != nil { + // probably not a timestamp + continue + } + + if versionTime.After(time.Unix(0, 0)) { + migrations = append(migrations, m) + } + } + return migrations, nil +} + func (ms Migrations) String() string { str := "" for _, m := range ms {