diff --git a/cmd_status.go b/cmd_status.go index 2000fa4..38a45a6 100644 --- a/cmd_status.go +++ b/cmd_status.go @@ -35,7 +35,7 @@ func statusRun(cmd *Command, args ...string) { log.Fatal(e) } - db, e := sql.Open(conf.Driver, conf.OpenStr) + db, e := sql.Open(conf.Driver.Name, conf.Driver.OpenStr) if e != nil { log.Fatal("couldn't open DB:", e) } diff --git a/dbconf.go b/dbconf.go index 919a356..6d7486c 100644 --- a/dbconf.go +++ b/dbconf.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "github.com/kylelemons/go-gypsy/yaml" @@ -13,11 +14,18 @@ import ( var dbPath = flag.String("path", "db", "folder containing db info") var dbEnv = flag.String("env", "development", "which DB environment to use") +// DBDriver encapsulates the info needed to work with +// a specific database driver +type DBDriver struct { + Name string + OpenStr string + Import string +} + type DBConf struct { MigrationsDir string Env string - Driver string - OpenStr string + Driver DBDriver } // default helper - makes a DBConf from the dbPath and dbEnv flags @@ -55,10 +63,47 @@ func makeDBConfDetails(p, env string) (*DBConf, error) { } } + d := NewDBDriver(drv, open) + + // XXX: allow an import entry to override DBDriver.Import + + if !d.IsValid() { + return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d)) + } + return &DBConf{ MigrationsDir: filepath.Join(p, "migrations"), Env: env, - Driver: drv, - OpenStr: open, + Driver: d, }, nil } + +// Create a new DBDriver and populate driver specific +// fields for drivers that we know about. +// Further customization may be done in NewDBConf +func NewDBDriver(name, open string) DBDriver { + + d := DBDriver{ + Name: name, + OpenStr: open, + } + + switch name { + case "postgres": + d.Import = "github.com/lib/pq" + + case "mymysql": + d.Import = "github.com/ziutek/mymysql/godrv" + } + + return d +} + +// ensure we have enough info about this driver +func (drv *DBDriver) IsValid() bool { + if len(drv.Import) == 0 { + return false + } + + return true +} diff --git a/dbconf_test.go b/dbconf_test.go index a624d33..3aa461f 100644 --- a/dbconf_test.go +++ b/dbconf_test.go @@ -12,7 +12,7 @@ func TestBasics(t *testing.T) { t.Error("couldn't create DBConf") } - got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver, dbconf.OpenStr} + got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr} want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"} for i, s := range got { diff --git a/migrate.go b/migrate.go index 32bb9b0..16ce0fd 100644 --- a/migrate.go +++ b/migrate.go @@ -42,7 +42,7 @@ type MigrationMap struct { func runMigrations(conf *DBConf, migrationsDir string, target int64) { - db, err := sql.Open(conf.Driver, conf.OpenStr) + db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr) if err != nil { log.Fatal("couldn't open DB:", err) } @@ -267,7 +267,7 @@ func createVersionTable(conf *DBConf, db *sql.DB) error { // their own DB instance func getDBVersion(conf *DBConf) int64 { - db, err := sql.Open(conf.Driver, conf.OpenStr) + db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr) if err != nil { log.Fatal("couldn't open DB:", err) } diff --git a/migration_go.go b/migration_go.go index 6127b25..43e29e5 100644 --- a/migration_go.go +++ b/migration_go.go @@ -12,8 +12,7 @@ import ( type TemplateData struct { Version int64 - DBDriver string - DBOpen string + Driver DBDriver Direction bool Func string } @@ -41,8 +40,7 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er td := &TemplateData{ Version: version, - DBDriver: conf.Driver, - DBOpen: conf.OpenStr, + Driver: conf.Driver, Direction: direction, Func: fmt.Sprintf("%v_%v", directionStr, version), } @@ -76,13 +74,13 @@ package main import ( "database/sql" - _ "github.com/lib/pq" + _ "{{.Driver.Import}}" "log" "fmt" ) func main() { - db, err := sql.Open("{{.DBDriver}}", "{{.DBOpen}}") + db, err := sql.Open("{{.Driver.Name}}", "{{.Driver.OpenStr}}") if err != nil { log.Fatal("failed to open DB:", err) }