From 161a0a1cc2e1d58b70d2a1bb68a18b5cf3a0742c Mon Sep 17 00:00:00 2001 From: Liam Staskawicz Date: Sun, 31 Mar 2013 12:22:52 -0700 Subject: [PATCH] reorg: plumb dbconf through to consumers that need it for dialect-specific operations --- cmd_status.go | 2 +- migrate.go | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cmd_status.go b/cmd_status.go index e30ec9d..2000fa4 100644 --- a/cmd_status.go +++ b/cmd_status.go @@ -42,7 +42,7 @@ func statusRun(cmd *Command, args ...string) { defer db.Close() // must ensure that the version table exists if we're running on a pristine DB - if _, e := ensureDBVersion(db); e != nil { + if _, e := ensureDBVersion(conf, db); e != nil { log.Fatal(e) } diff --git a/migrate.go b/migrate.go index 4911597..32bb9b0 100644 --- a/migrate.go +++ b/migrate.go @@ -48,7 +48,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) { } defer db.Close() - current, e := ensureDBVersion(db) + current, e := ensureDBVersion(conf, db) if e != nil { log.Fatalf("couldn't get DB version: %v", e) } @@ -190,13 +190,13 @@ func numericComponent(name string) (int64, error) { // retrieve the current version for this DB. // Create and initialize the DB version table if it doesn't exist. -func ensureDBVersion(db *sql.DB) (int64, error) { +func ensureDBVersion(conf *DBConf, db *sql.DB) (int64, error) { rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC;") if err != nil { // XXX: cross platform method to detect failure reason // for now, assume it was because the table didn't exist, and try to create it - return 0, createVersionTable(db) + return 0, createVersionTable(conf, db) } // The most recent record for each migration specifies @@ -235,7 +235,9 @@ func ensureDBVersion(db *sql.DB) (int64, error) { panic("failure in ensureDBVersion()") } -func createVersionTable(db *sql.DB) error { +// Create the goose_db_version table +// and insert the initial 0 value into it +func createVersionTable(conf *DBConf, db *sql.DB) error { txn, err := db.Begin() if err != nil { return err @@ -271,7 +273,7 @@ func getDBVersion(conf *DBConf) int64 { } defer db.Close() - version, err := ensureDBVersion(db) + version, err := ensureDBVersion(conf, db) if err != nil { log.Fatalf("couldn't get DB version: %v", err) }