mirror of https://github.com/pressly/goose.git
125 lines
2.6 KiB
Go
125 lines
2.6 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"github.com/kylelemons/go-gypsy/yaml"
|
|
"github.com/lib/pq"
|
|
"os"
|
|
"path/filepath"
|
|
)
|
|
|
|
// Global options, available to every subcommand.
var (
	// dbPath is the -path flag: the folder containing dbconf.yml and migrations.
	dbPath = flag.String("path", "db", "folder containing db info")
	// dbEnv is the -env flag: the dbconf.yml section to use.
	dbEnv = flag.String("env", "development", "which DB environment to use")
)
|
|
|
|
// DBDriver encapsulates the info needed to work with
|
|
// a specific database driver
|
|
type DBDriver struct {
|
|
Name string
|
|
OpenStr string
|
|
Import string
|
|
Dialect SqlDialect
|
|
}
|
|
|
|
type DBConf struct {
|
|
MigrationsDir string
|
|
Env string
|
|
Driver DBDriver
|
|
}
|
|
|
|
// default helper - makes a DBConf from the dbPath and dbEnv flags
|
|
func NewDBConf() (*DBConf, error) {
|
|
return newDBConfDetails(*dbPath, *dbEnv)
|
|
}
|
|
|
|
// extract configuration details from the given file
|
|
func newDBConfDetails(p, env string) (*DBConf, error) {
|
|
|
|
cfgFile := filepath.Join(p, "dbconf.yml")
|
|
|
|
f, err := yaml.ReadFile(cfgFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
drv, err := f.Get(fmt.Sprintf("%s.driver", env))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
open, err := f.Get(fmt.Sprintf("%s.open", env))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
open = os.ExpandEnv(open)
|
|
|
|
// Automatically parse postgres urls
|
|
if drv == "postgres" {
|
|
|
|
// Assumption: If we can parse the URL, we should
|
|
if parsedURL, err := pq.ParseURL(open); err == nil && parsedURL != "" {
|
|
open = parsedURL
|
|
}
|
|
}
|
|
|
|
d := NewDBDriver(drv, open)
|
|
|
|
// allow the configuration to override the Import for this driver
|
|
if imprt, err := f.Get(fmt.Sprintf("%s.import", env)); err == nil {
|
|
d.Import = imprt
|
|
}
|
|
|
|
// allow the configuration to override the Dialect for this driver
|
|
if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil {
|
|
d.Dialect = DialectByName(dialect)
|
|
}
|
|
|
|
if !d.IsValid() {
|
|
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
|
|
}
|
|
|
|
return &DBConf{
|
|
MigrationsDir: filepath.Join(p, "migrations"),
|
|
Env: env,
|
|
Driver: d,
|
|
}, nil
|
|
}
|
|
|
|
// Create a new DBDriver and populate driver specific
|
|
// fields for drivers that we know about.
|
|
// Further customization may be done in NewDBConf
|
|
func NewDBDriver(name, open string) DBDriver {
|
|
|
|
d := DBDriver{
|
|
Name: name,
|
|
OpenStr: open,
|
|
}
|
|
|
|
switch name {
|
|
case "postgres":
|
|
d.Import = "github.com/lib/pq"
|
|
d.Dialect = &PostgresDialect{}
|
|
|
|
case "mymysql":
|
|
d.Import = "github.com/ziutek/mymysql/godrv"
|
|
d.Dialect = &MySqlDialect{}
|
|
}
|
|
|
|
return d
|
|
}
|
|
|
|
// ensure we have enough info about this driver
|
|
func (drv *DBDriver) IsValid() bool {
|
|
if len(drv.Import) == 0 {
|
|
return false
|
|
}
|
|
|
|
if drv.Dialect == nil {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|