add fix and timestamp default

pull/120/head
1vn 2018-10-30 16:45:45 -04:00
parent a4bf952640
commit 45eeb19d7d
6 changed files with 191 additions and 12 deletions

View File

@ -6,22 +6,12 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"text/template" "text/template"
"time"
) )
// Create writes a new blank migration file. // Create writes a new blank migration file.
func CreateWithTemplate(db *sql.DB, dir string, migrationTemplate *template.Template, name, migrationType string) error { func CreateWithTemplate(db *sql.DB, dir string, migrationTemplate *template.Template, name, migrationType string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion) version := time.Now().Format(timestampFormat)
if err != nil {
return err
}
// Initial version.
version := "00001"
if last, err := migrations.Last(); err == nil {
version = fmt.Sprintf("%05v", last.Version+1)
}
filename := fmt.Sprintf("%v_%v.%v", version, name, migrationType) filename := fmt.Sprintf("%v_%v.%v", version, name, migrationType)
fpath := filepath.Join(dir, filename) fpath := filepath.Join(dir, filename)

View File

@ -61,6 +61,10 @@ func (pg PostgresDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
} }
func (pg PostgresDialect) updateVersionSQL() string {
return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName())
}
func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil { if err != nil {
@ -91,6 +95,10 @@ func (m MySQLDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
} }
func (m MySQLDialect) updateVersionSQL() string {
return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName())
}
func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil { if err != nil {
@ -120,6 +128,10 @@ func (m Sqlite3Dialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
} }
func (m Sqlite3Dialect) updateVersionSQL() string {
return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName())
}
func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil { if err != nil {
@ -150,6 +162,10 @@ func (rs RedshiftDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
} }
func (rs RedshiftDialect) updateVersionSQL() string {
return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName())
}
func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil { if err != nil {
@ -180,6 +196,10 @@ func (m TiDBDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
} }
func (m TiDBDialect) updateVersionSQL() string {
return fmt.Sprintf("UPDATE %s SET version_id=? WHERE version_id=?;", TableName())
}
func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil { if err != nil {

45
fix.go Normal file
View File

@ -0,0 +1,45 @@
package goose
import (
"database/sql"
"fmt"
"os"
"strings"
)
func Fix(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil {
return err
}
// split into timestamped and versioned migrations
tsMigrations, err := migrations.Timestamped()
if err != nil {
return err
}
vMigrations, err := migrations.Versioned()
if err != nil {
return err
}
// Initial version.
version := int64(1)
if last, err := vMigrations.Last(); err == nil {
version = last.Version + 1
}
// fix filenames by replacing timestamps with sequential versions
for _, tsm := range tsMigrations {
oldPath := tsm.Source
newPath := strings.Replace(oldPath, fmt.Sprintf("%d", tsm.Version), fmt.Sprintf("%05v", version), 1)
if err := os.Rename(oldPath, newPath); err != nil {
return err
}
version++
}
return nil
}

81
fix_test.go Normal file
View File

@ -0,0 +1,81 @@
package goose
import (
"fmt"
"io/ioutil"
"os"
"os/exec"
"strings"
"testing"
"time"
)
func TestFix(t *testing.T) {
dir, err := ioutil.TempDir("", "tmptest")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir) // clean up
defer os.Remove("goose") // clean up
commands := []string{
"go build -i -o goose ./cmd/goose",
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s sqlite3 sql.db fix", dir),
}
for _, cmd := range commands {
args := strings.Split(cmd, " ")
time.Sleep(1 * time.Second)
out, err := exec.Command(args[0], args[1:]...).CombinedOutput()
if err != nil {
t.Fatalf("%s:\n%v\n\n%s", err, cmd, out)
}
}
files, err := ioutil.ReadDir(dir)
if err != nil {
t.Fatal(err)
}
// check that the files are in order
for i, f := range files {
expected := fmt.Sprintf("%05v", i+1)
if !strings.HasPrefix(f.Name(), expected) {
t.Errorf("failed to find %s prefix in %s", expected, f.Name())
}
}
// add more migrations and then fix it
commands = []string{
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s create yolo", dir),
fmt.Sprintf("./goose -dir=%s sqlite3 sql.db fix", dir),
}
for _, cmd := range commands {
args := strings.Split(cmd, " ")
time.Sleep(1 * time.Second)
out, err := exec.Command(args[0], args[1:]...).CombinedOutput()
if err != nil {
t.Fatalf("%s:\n%v\n\n%s", err, cmd, out)
}
}
files, err = ioutil.ReadDir(dir)
if err != nil {
t.Fatal(err)
}
// check that the files still in order
for i, f := range files {
expected := fmt.Sprintf("%05v", i+1)
if !strings.HasPrefix(f.Name(), expected) {
t.Errorf("failed to find %s prefix in %s", expected, f.Name())
}
}
}

View File

@ -11,6 +11,7 @@ var (
duplicateCheckOnce sync.Once duplicateCheckOnce sync.Once
minVersion = int64(0) minVersion = int64(0)
maxVersion = int64((1 << 63) - 1) maxVersion = int64((1 << 63) - 1)
timestampFormat = "20060102150405"
) )
// Run runs a goose command. // Run runs a goose command.
@ -64,6 +65,10 @@ func Run(command string, db *sql.DB, dir string, args ...string) error {
if err := DownTo(db, dir, version); err != nil { if err := DownTo(db, dir, version); err != nil {
return err return err
} }
case "fix":
if err := Fix(db, dir); err != nil {
return err
}
case "redo": case "redo":
if err := Redo(db, dir); err != nil { if err := Redo(db, dir); err != nil {
return err return err

View File

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort" "sort"
"time"
) )
var ( var (
@ -76,6 +77,43 @@ func (ms Migrations) Last() (*Migration, error) {
return ms[len(ms)-1], nil return ms[len(ms)-1], nil
} }
// Versioned gets versioned migrations.
func (ms Migrations) Versioned() (Migrations, error) {
var migrations Migrations
// assume that the user will never have more than 19700101000000 migrations
for _, m := range ms {
// parse version as timestmap
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
if versionTime.Before(time.Unix(0, 0)) || err != nil {
migrations = append(migrations, m)
}
}
return migrations, nil
}
// Timestamped gets the timestamped migrations.
func (ms Migrations) Timestamped() (Migrations, error) {
var migrations Migrations
// assume that the user will never have more than 19700101000000 migrations
for _, m := range ms {
// parse version as timestmap
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
if err != nil {
// probably not a timestamp
continue
}
if versionTime.After(time.Unix(0, 0)) {
migrations = append(migrations, m)
}
}
return migrations, nil
}
func (ms Migrations) String() string { func (ms Migrations) String() string {
str := "" str := ""
for _, m := range ms { for _, m := range ms {