dbconf: split out separate DBDriver struct to encapsulate the info required for a given driver. as a bonus, we now import the correct package for Go migrations based on the driver - previously, we imported postgres only.

pull/2/head
Liam Staskawicz 2013-04-07 14:14:11 -07:00
parent 161a0a1cc2
commit 4446df2ca6
5 changed files with 57 additions and 14 deletions

View File

@ -35,7 +35,7 @@ func statusRun(cmd *Command, args ...string) {
log.Fatal(e) log.Fatal(e)
} }
db, e := sql.Open(conf.Driver, conf.OpenStr) db, e := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
if e != nil { if e != nil {
log.Fatal("couldn't open DB:", e) log.Fatal("couldn't open DB:", e)
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"github.com/kylelemons/go-gypsy/yaml" "github.com/kylelemons/go-gypsy/yaml"
@ -13,11 +14,18 @@ import (
var dbPath = flag.String("path", "db", "folder containing db info") var dbPath = flag.String("path", "db", "folder containing db info")
var dbEnv = flag.String("env", "development", "which DB environment to use") var dbEnv = flag.String("env", "development", "which DB environment to use")
// DBDriver encapsulates the info needed to work with
// a specific database driver
type DBDriver struct {
Name string
OpenStr string
Import string
}
type DBConf struct { type DBConf struct {
MigrationsDir string MigrationsDir string
Env string Env string
Driver string Driver DBDriver
OpenStr string
} }
// default helper - makes a DBConf from the dbPath and dbEnv flags // default helper - makes a DBConf from the dbPath and dbEnv flags
@ -55,10 +63,47 @@ func makeDBConfDetails(p, env string) (*DBConf, error) {
} }
} }
d := NewDBDriver(drv, open)
// XXX: allow an import entry to override DBDriver.Import
if !d.IsValid() {
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
}
return &DBConf{ return &DBConf{
MigrationsDir: filepath.Join(p, "migrations"), MigrationsDir: filepath.Join(p, "migrations"),
Env: env, Env: env,
Driver: drv, Driver: d,
OpenStr: open,
}, nil }, nil
} }
// Create a new DBDriver and populate driver specific
// fields for drivers that we know about.
// Further customization may be done in NewDBConf
func NewDBDriver(name, open string) DBDriver {
d := DBDriver{
Name: name,
OpenStr: open,
}
switch name {
case "postgres":
d.Import = "github.com/lib/pq"
case "mymysql":
d.Import = "github.com/ziutek/mymysql/godrv"
}
return d
}
// ensure we have enough info about this driver
func (drv *DBDriver) IsValid() bool {
if len(drv.Import) == 0 {
return false
}
return true
}

View File

@ -12,7 +12,7 @@ func TestBasics(t *testing.T) {
t.Error("couldn't create DBConf") t.Error("couldn't create DBConf")
} }
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver, dbconf.OpenStr} got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"} want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
for i, s := range got { for i, s := range got {

View File

@ -42,7 +42,7 @@ type MigrationMap struct {
func runMigrations(conf *DBConf, migrationsDir string, target int64) { func runMigrations(conf *DBConf, migrationsDir string, target int64) {
db, err := sql.Open(conf.Driver, conf.OpenStr) db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
if err != nil { if err != nil {
log.Fatal("couldn't open DB:", err) log.Fatal("couldn't open DB:", err)
} }
@ -267,7 +267,7 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
// their own DB instance // their own DB instance
func getDBVersion(conf *DBConf) int64 { func getDBVersion(conf *DBConf) int64 {
db, err := sql.Open(conf.Driver, conf.OpenStr) db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
if err != nil { if err != nil {
log.Fatal("couldn't open DB:", err) log.Fatal("couldn't open DB:", err)
} }

View File

@ -12,8 +12,7 @@ import (
type TemplateData struct { type TemplateData struct {
Version int64 Version int64
DBDriver string Driver DBDriver
DBOpen string
Direction bool Direction bool
Func string Func string
} }
@ -41,8 +40,7 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
td := &TemplateData{ td := &TemplateData{
Version: version, Version: version,
DBDriver: conf.Driver, Driver: conf.Driver,
DBOpen: conf.OpenStr,
Direction: direction, Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version), Func: fmt.Sprintf("%v_%v", directionStr, version),
} }
@ -76,13 +74,13 @@ package main
import ( import (
"database/sql" "database/sql"
_ "github.com/lib/pq" _ "{{.Driver.Import}}"
"log" "log"
"fmt" "fmt"
) )
func main() { func main() {
db, err := sql.Open("{{.DBDriver}}", "{{.DBOpen}}") db, err := sql.Open("{{.Driver.Name}}", "{{.Driver.OpenStr}}")
if err != nil { if err != nil {
log.Fatal("failed to open DB:", err) log.Fatal("failed to open DB:", err)
} }