mirror of https://github.com/pressly/goose.git
dbconf: split out separate DBDriver struct to encapsulate the info required for a given driver. as a bonus, we now import the correct package for Go migrations based on the driver - previously, we imported postgres only.
parent
161a0a1cc2
commit
4446df2ca6
|
@ -35,7 +35,7 @@ func statusRun(cmd *Command, args ...string) {
|
|||
log.Fatal(e)
|
||||
}
|
||||
|
||||
db, e := sql.Open(conf.Driver, conf.OpenStr)
|
||||
db, e := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||
if e != nil {
|
||||
log.Fatal("couldn't open DB:", e)
|
||||
}
|
||||
|
|
53
dbconf.go
53
dbconf.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/kylelemons/go-gypsy/yaml"
|
||||
|
@ -13,11 +14,18 @@ import (
|
|||
var dbPath = flag.String("path", "db", "folder containing db info")
|
||||
var dbEnv = flag.String("env", "development", "which DB environment to use")
|
||||
|
||||
// DBDriver encapsulates the info needed to work with
|
||||
// a specific database driver
|
||||
type DBDriver struct {
|
||||
Name string
|
||||
OpenStr string
|
||||
Import string
|
||||
}
|
||||
|
||||
type DBConf struct {
|
||||
MigrationsDir string
|
||||
Env string
|
||||
Driver string
|
||||
OpenStr string
|
||||
Driver DBDriver
|
||||
}
|
||||
|
||||
// default helper - makes a DBConf from the dbPath and dbEnv flags
|
||||
|
@ -55,10 +63,47 @@ func makeDBConfDetails(p, env string) (*DBConf, error) {
|
|||
}
|
||||
}
|
||||
|
||||
d := NewDBDriver(drv, open)
|
||||
|
||||
// XXX: allow an import entry to override DBDriver.Import
|
||||
|
||||
if !d.IsValid() {
|
||||
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
|
||||
}
|
||||
|
||||
return &DBConf{
|
||||
MigrationsDir: filepath.Join(p, "migrations"),
|
||||
Env: env,
|
||||
Driver: drv,
|
||||
OpenStr: open,
|
||||
Driver: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create a new DBDriver and populate driver specific
|
||||
// fields for drivers that we know about.
|
||||
// Further customization may be done in NewDBConf
|
||||
func NewDBDriver(name, open string) DBDriver {
|
||||
|
||||
d := DBDriver{
|
||||
Name: name,
|
||||
OpenStr: open,
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "postgres":
|
||||
d.Import = "github.com/lib/pq"
|
||||
|
||||
case "mymysql":
|
||||
d.Import = "github.com/ziutek/mymysql/godrv"
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// ensure we have enough info about this driver
|
||||
func (drv *DBDriver) IsValid() bool {
|
||||
if len(drv.Import) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ func TestBasics(t *testing.T) {
|
|||
t.Error("couldn't create DBConf")
|
||||
}
|
||||
|
||||
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver, dbconf.OpenStr}
|
||||
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
|
||||
want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
|
||||
|
||||
for i, s := range got {
|
||||
|
|
|
@ -42,7 +42,7 @@ type MigrationMap struct {
|
|||
|
||||
func runMigrations(conf *DBConf, migrationsDir string, target int64) {
|
||||
|
||||
db, err := sql.Open(conf.Driver, conf.OpenStr)
|
||||
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||
if err != nil {
|
||||
log.Fatal("couldn't open DB:", err)
|
||||
}
|
||||
|
@ -267,7 +267,7 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
|
|||
// their own DB instance
|
||||
func getDBVersion(conf *DBConf) int64 {
|
||||
|
||||
db, err := sql.Open(conf.Driver, conf.OpenStr)
|
||||
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||
if err != nil {
|
||||
log.Fatal("couldn't open DB:", err)
|
||||
}
|
||||
|
|
|
@ -12,8 +12,7 @@ import (
|
|||
|
||||
type TemplateData struct {
|
||||
Version int64
|
||||
DBDriver string
|
||||
DBOpen string
|
||||
Driver DBDriver
|
||||
Direction bool
|
||||
Func string
|
||||
}
|
||||
|
@ -41,8 +40,7 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
|
|||
|
||||
td := &TemplateData{
|
||||
Version: version,
|
||||
DBDriver: conf.Driver,
|
||||
DBOpen: conf.OpenStr,
|
||||
Driver: conf.Driver,
|
||||
Direction: direction,
|
||||
Func: fmt.Sprintf("%v_%v", directionStr, version),
|
||||
}
|
||||
|
@ -76,13 +74,13 @@ package main
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "{{.Driver.Import}}"
|
||||
"log"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("{{.DBDriver}}", "{{.DBOpen}}")
|
||||
db, err := sql.Open("{{.Driver.Name}}", "{{.Driver.OpenStr}}")
|
||||
if err != nil {
|
||||
log.Fatal("failed to open DB:", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue