diff --git a/cmd/goose/main.go b/cmd/goose/main.go index dcc5910..87b9506 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -26,8 +26,8 @@ func main() { args := flags.Args() - if len(args) > 1 && args[0] == "create" { - if err := goose.Run("create", nil, *dir, args[1:]...); err != nil { + if len(args) > 1 && (args[0] == "create" || args[0] == "fix") { + if err := goose.Run(args[0], nil, *dir, args[1:]...); err != nil { log.Fatalf("goose run: %v", err) } return diff --git a/fix.go b/fix.go index caa9ff8..34f5705 100644 --- a/fix.go +++ b/fix.go @@ -1,25 +1,24 @@ package goose import ( - "database/sql" "fmt" "os" "strings" ) -func Fix(db *sql.DB, dir string) error { +func Fix(dir string) error { migrations, err := CollectMigrations(dir, minVersion, maxVersion) if err != nil { return err } // split into timestamped and versioned migrations - tsMigrations, err := migrations.Timestamped() + tsMigrations, err := migrations.timestamped() if err != nil { return err } - vMigrations, err := migrations.Versioned() + vMigrations, err := migrations.versioned() if err != nil { return err } @@ -29,12 +28,6 @@ func Fix(db *sql.DB, dir string) error { version = last.Version + 1 } - // fix db table as well - tx, err := db.Begin() - if err != nil { - log.Fatal("db.Begin: ", err) - } - // fix filenames by replacing timestamps with sequential versions for _, tsm := range tsMigrations { oldPath := tsm.Source @@ -44,13 +37,8 @@ func Fix(db *sql.DB, dir string) error { return err } - if _, err := tx.Exec(GetDialect().updateVersionSQL(), version, tsm.Version); err != nil { - tx.Rollback() - return err - } - version++ } - return tx.Commit() + return nil } diff --git a/goose.go b/goose.go index 1c492b0..ff4b5e8 100644 --- a/goose.go +++ b/goose.go @@ -66,7 +66,7 @@ func Run(command string, db *sql.DB, dir string, args ...string) error { return err } case "fix": - if err := Fix(db, dir); err != nil { + if err := Fix(dir); err != nil { return err } case "redo": diff --git a/migrate.go b/migrate.go index f998b27..5529442 100644 --- a/migrate.go +++ b/migrate.go @@ -78,7 +78,7 @@ func (ms Migrations) Last() (*Migration, error) { } // Versioned gets versioned migrations. -func (ms Migrations) Versioned() (Migrations, error) { +func (ms Migrations) versioned() (Migrations, error) { var migrations Migrations // assume that the user will never have more than 19700101000000 migrations @@ -95,7 +95,7 @@ func (ms Migrations) Versioned() (Migrations, error) { } // Timestamped gets the timestamped migrations. -func (ms Migrations) Timestamped() (Migrations, error) { +func (ms Migrations) timestamped() (Migrations, error) { var migrations Migrations // assume that the user will never have more than 19700101000000 migrations