mirror of https://github.com/pressly/goose.git
dbconf: split out separate DBDriver struct to encapsulate the info required for a given driver. as a bonus, we now import the correct package for Go migrations based on the driver - previously, we imported postgres only.
parent
161a0a1cc2
commit
4446df2ca6
|
@ -35,7 +35,7 @@ func statusRun(cmd *Command, args ...string) {
|
||||||
log.Fatal(e)
|
log.Fatal(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, e := sql.Open(conf.Driver, conf.OpenStr)
|
db, e := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Fatal("couldn't open DB:", e)
|
log.Fatal("couldn't open DB:", e)
|
||||||
}
|
}
|
||||||
|
|
53
dbconf.go
53
dbconf.go
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/kylelemons/go-gypsy/yaml"
|
"github.com/kylelemons/go-gypsy/yaml"
|
||||||
|
@ -13,11 +14,18 @@ import (
|
||||||
var dbPath = flag.String("path", "db", "folder containing db info")
|
var dbPath = flag.String("path", "db", "folder containing db info")
|
||||||
var dbEnv = flag.String("env", "development", "which DB environment to use")
|
var dbEnv = flag.String("env", "development", "which DB environment to use")
|
||||||
|
|
||||||
|
// DBDriver encapsulates the info needed to work with
|
||||||
|
// a specific database driver
|
||||||
|
type DBDriver struct {
|
||||||
|
Name string
|
||||||
|
OpenStr string
|
||||||
|
Import string
|
||||||
|
}
|
||||||
|
|
||||||
type DBConf struct {
|
type DBConf struct {
|
||||||
MigrationsDir string
|
MigrationsDir string
|
||||||
Env string
|
Env string
|
||||||
Driver string
|
Driver DBDriver
|
||||||
OpenStr string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// default helper - makes a DBConf from the dbPath and dbEnv flags
|
// default helper - makes a DBConf from the dbPath and dbEnv flags
|
||||||
|
@ -55,10 +63,47 @@ func makeDBConfDetails(p, env string) (*DBConf, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d := NewDBDriver(drv, open)
|
||||||
|
|
||||||
|
// XXX: allow an import entry to override DBDriver.Import
|
||||||
|
|
||||||
|
if !d.IsValid() {
|
||||||
|
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
|
||||||
|
}
|
||||||
|
|
||||||
return &DBConf{
|
return &DBConf{
|
||||||
MigrationsDir: filepath.Join(p, "migrations"),
|
MigrationsDir: filepath.Join(p, "migrations"),
|
||||||
Env: env,
|
Env: env,
|
||||||
Driver: drv,
|
Driver: d,
|
||||||
OpenStr: open,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a new DBDriver and populate driver specific
|
||||||
|
// fields for drivers that we know about.
|
||||||
|
// Further customization may be done in NewDBConf
|
||||||
|
func NewDBDriver(name, open string) DBDriver {
|
||||||
|
|
||||||
|
d := DBDriver{
|
||||||
|
Name: name,
|
||||||
|
OpenStr: open,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "postgres":
|
||||||
|
d.Import = "github.com/lib/pq"
|
||||||
|
|
||||||
|
case "mymysql":
|
||||||
|
d.Import = "github.com/ziutek/mymysql/godrv"
|
||||||
|
}
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure we have enough info about this driver
|
||||||
|
func (drv *DBDriver) IsValid() bool {
|
||||||
|
if len(drv.Import) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ func TestBasics(t *testing.T) {
|
||||||
t.Error("couldn't create DBConf")
|
t.Error("couldn't create DBConf")
|
||||||
}
|
}
|
||||||
|
|
||||||
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver, dbconf.OpenStr}
|
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
|
||||||
want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
|
want := []string{"db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
|
||||||
|
|
||||||
for i, s := range got {
|
for i, s := range got {
|
||||||
|
|
|
@ -42,7 +42,7 @@ type MigrationMap struct {
|
||||||
|
|
||||||
func runMigrations(conf *DBConf, migrationsDir string, target int64) {
|
func runMigrations(conf *DBConf, migrationsDir string, target int64) {
|
||||||
|
|
||||||
db, err := sql.Open(conf.Driver, conf.OpenStr)
|
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("couldn't open DB:", err)
|
log.Fatal("couldn't open DB:", err)
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,7 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
|
||||||
// their own DB instance
|
// their own DB instance
|
||||||
func getDBVersion(conf *DBConf) int64 {
|
func getDBVersion(conf *DBConf) int64 {
|
||||||
|
|
||||||
db, err := sql.Open(conf.Driver, conf.OpenStr)
|
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("couldn't open DB:", err)
|
log.Fatal("couldn't open DB:", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,7 @@ import (
|
||||||
|
|
||||||
type TemplateData struct {
|
type TemplateData struct {
|
||||||
Version int64
|
Version int64
|
||||||
DBDriver string
|
Driver DBDriver
|
||||||
DBOpen string
|
|
||||||
Direction bool
|
Direction bool
|
||||||
Func string
|
Func string
|
||||||
}
|
}
|
||||||
|
@ -41,8 +40,7 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
|
||||||
|
|
||||||
td := &TemplateData{
|
td := &TemplateData{
|
||||||
Version: version,
|
Version: version,
|
||||||
DBDriver: conf.Driver,
|
Driver: conf.Driver,
|
||||||
DBOpen: conf.OpenStr,
|
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Func: fmt.Sprintf("%v_%v", directionStr, version),
|
Func: fmt.Sprintf("%v_%v", directionStr, version),
|
||||||
}
|
}
|
||||||
|
@ -76,13 +74,13 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
_ "github.com/lib/pq"
|
_ "{{.Driver.Import}}"
|
||||||
"log"
|
"log"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
db, err := sql.Open("{{.DBDriver}}", "{{.DBOpen}}")
|
db, err := sql.Open("{{.Driver.Name}}", "{{.Driver.OpenStr}}")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open DB:", err)
|
log.Fatal("failed to open DB:", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue