diff --git a/cmd/goose/main.go b/cmd/goose/main.go index c589029..c59079e 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -116,6 +116,7 @@ Commands: down Roll back the version by 1 down-to VERSION Roll back to a specific VERSION redo Re-run the latest migration + reset Roll back all migrations status Dump the migration status for the current DB version Print the current version of the database create NAME [sql|go] Creates new migration file with next version diff --git a/goose.go b/goose.go index 3861e41..51ca6ec 100644 --- a/goose.go +++ b/goose.go @@ -68,6 +68,10 @@ func Run(command string, db *sql.DB, dir string, args ...string) error { if err := Redo(db, dir); err != nil { return err } + case "reset": + if err := Reset(db, dir); err != nil { + return err + } case "status": if err := Status(db, dir); err != nil { return err diff --git a/reset.go b/reset.go new file mode 100644 index 0000000..c7e0b4d --- /dev/null +++ b/reset.go @@ -0,0 +1,59 @@ +package goose + +import ( + "database/sql" + "log" + "sort" +) + +// Reset rolls back all migrations +func Reset(db *sql.DB, dir string) error { + migrations, err := CollectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + statuses, err := dbMigrationsStatus(db) + if err != nil { + return err + } + sort.Sort(sort.Reverse(migrations)) + + for _, migration := range migrations { + if !statuses[migration.Version] { + continue + } + if err = migration.Down(db); err != nil { + return err + } + } + + return nil +} + +func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) { + rows, err := GetDialect().dbVersionQuery(db) + if err != nil { + return map[int64]bool{}, createVersionTable(db) + } + defer rows.Close() + + // The most recent record for each migration specifies + // whether it has been applied or rolled back. + + result := make(map[int64]bool) + + for rows.Next() { + var row MigrationRecord + if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil { + log.Fatal("error scanning rows:", err) + } + + if _, ok := result[row.VersionID]; ok { + continue + } + + result[row.VersionID] = row.IsApplied + } + + return result, nil +}