diff --git a/down.go b/down.go index f02f774..944a3fc 100644 --- a/down.go +++ b/down.go @@ -1,7 +1,6 @@ package main import ( - "database/sql" "fmt" "log" "os" @@ -39,22 +38,6 @@ func downRun(cmd *Command, args ...string) { runMigrations(conf, conf.MigrationsDir, previous) } -func getDBVersion(conf *DBConf) int { - - db, err := sql.Open(conf.Driver, conf.OpenStr) - if err != nil { - log.Fatal("couldn't open DB:", err) - } - defer db.Close() - - version, err := ensureDBVersion(db) - if err != nil { - log.Fatalf("couldn't get DB version: %v", err) - } - - return version -} - func getPreviousVersion(dirpath string, version int) (previous, earliest int) { previous = -1 diff --git a/migrate.go b/migrate.go index 93fe57b..7c3e188 100644 --- a/migrate.go +++ b/migrate.go @@ -232,3 +232,21 @@ func ensureDBVersion(db *sql.DB) (int, error) { return 0, txn.Commit() } + +// wrapper for ensureDBVersion for callers that don't already have +// their own DB instance +func getDBVersion(conf *DBConf) int { + + db, err := sql.Open(conf.Driver, conf.OpenStr) + if err != nil { + log.Fatal("couldn't open DB:", err) + } + defer db.Close() + + version, err := ensureDBVersion(db) + if err != nil { + log.Fatalf("couldn't get DB version: %v", err) + } + + return version +}