diff --git a/cmd.go b/cmd.go new file mode 100644 index 0000000..88fd97c --- /dev/null +++ b/cmd.go @@ -0,0 +1,35 @@ +package main + +import ( + "flag" + // "fmt" + "os" +) + +// shamelessly snagged from the go tool +// each command gets its own set of args, +// defines its own entry point, and provides its own help +type Command struct { + Run func(cmd *Command, args ...string) + Flag flag.FlagSet + + Name string + Usage string + + Summary string + Help string +} + +func (c *Command) Exec(args []string) { + c.Flag.Usage = func() { + // helpFunc(c, c.Name) + } + c.Flag.Parse(args) + defer func() { + if r := recover(); r != nil { + panic(r) + } + os.Exit(1) + }() + c.Run(c, c.Flag.Args()...) +} diff --git a/dbconf.go b/dbconf.go new file mode 100644 index 0000000..5cac7e6 --- /dev/null +++ b/dbconf.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "github.com/kylelemons/go-gypsy/yaml" +) + +type DBConf struct { + Name string + Driver string + OpenStr string +} + +// extract configuration details from the given file +func dbConfFromFile(path, envtype string) (*DBConf, error) { + + f, err := yaml.ReadFile(path) + if err != nil { + return nil, err + } + + drv, derr := f.Get(fmt.Sprintf("%s.driver", envtype)) + if derr != nil { + return nil, derr + } + + open, oerr := f.Get(fmt.Sprintf("%s.open", envtype)) + if oerr != nil { + return nil, oerr + } + + return &DBConf{ + Name: envtype, + Driver: drv, + OpenStr: open, + }, nil +} diff --git a/main.go b/main.go index 8a9fbd8..5f7f246 100644 --- a/main.go +++ b/main.go @@ -3,53 +3,39 @@ package main import ( "flag" "fmt" - "github.com/kylelemons/go-gypsy/yaml" - "log" - "path" + "os" + "strings" ) -type DBConf struct { - Name string - Driver string - OpenStr string +var commands = []*Command{ + upCmd, } -var dbFolder = flag.String("db", "db", "folder containing db info") -var dbConfName = flag.String("config", "development", "which DB configuration to use") -var targetVersion = flag.Int("target", -1, "which DB version to target (defaults to latest version)") - func main() { + + // XXX: create a flag.Usage that dumps all commands flag.Parse() - conf, err := dbConfFromFile(path.Join(*dbFolder, "dbconf.yml"), *dbConfName) - if err != nil { - log.Fatal(err) + args := flag.Args() + if len(args) == 0 { + flag.Usage() + return } - runMigrations(conf, path.Join(*dbFolder, "migrations"), *targetVersion) -} - -// extract configuration details from the given file -func dbConfFromFile(path, envtype string) (*DBConf, error) { - - f, err := yaml.ReadFile(path) - if err != nil { - return nil, err - } - - drv, derr := f.Get(fmt.Sprintf("%s.driver", envtype)) - if derr != nil { - return nil, derr - } - - open, oerr := f.Get(fmt.Sprintf("%s.open", envtype)) - if oerr != nil { - return nil, oerr - } - - return &DBConf{ - Name: envtype, - Driver: drv, - OpenStr: open, - }, nil + var cmd *Command + name := args[0] + for _, c := range commands { + if strings.HasPrefix(c.Name, name) { + cmd = c + break + } + } + + if cmd == nil { + fmt.Printf("error: unknown command %q\n", name) + flag.Usage() + os.Exit(1) + } + + cmd.Exec(args[1:]) } diff --git a/migrate.go b/migrate.go index 81db657..c4b3292 100644 --- a/migrate.go +++ b/migrate.go @@ -41,7 +41,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int) { current, e := ensureDBVersion(db) if e != nil { - log.Fatal("couldn't get/set DB version") + log.Fatalf("couldn't get DB version: %v", e) } mm, err := collectMigrations(migrationsDir, current, target) @@ -97,12 +97,6 @@ func collectMigrations(dirpath string, current, target int) (mm *MigrationMap, e Migrations: make(map[int]Migration), } - // if target is the default -1, - // we need to find the most recent possible version to target - if target < 0 { - target = mostRecentVersionAvailable(names) - } - // extract the numeric component of each migration, // filter out any uninteresting files, // and ensure we only have one file per migration version. @@ -134,31 +128,6 @@ func collectMigrations(dirpath string, current, target int) (mm *MigrationMap, e return mm, nil } -// helper to identify the most recent possible version -// within a folder of migration scripts -func mostRecentVersionAvailable(names []string) int { - - mostRecent := -1 - - for _, name := range names { - - if ext := path.Ext(name); ext != ".go" && ext != ".sql" { - continue - } - - v, e := numericComponent(name) - if e != nil { - continue - } - - if v > mostRecent { - mostRecent = v - } - } - - return mostRecent -} - func versionFilter(v, current, target int) bool { // special case - default target value diff --git a/up.go b/up.go new file mode 100644 index 0000000..90c4462 --- /dev/null +++ b/up.go @@ -0,0 +1,68 @@ +package main + +import ( + "log" + "os" + "path" +) + +var upCmd = &Command{ + Name: "up", + Usage: "", + Summary: "Migrate the DB to the most recent version available", + Help: `up extended help here...`, +} + +var dbFolder = upCmd.Flag.String("db", "db", "folder containing db info") +var dbConfName = upCmd.Flag.String("config", "development", "which DB configuration to use") + +func upRun(cmd *Command, args ...string) { + + conf, err := dbConfFromFile(path.Join(*dbFolder, "dbconf.yml"), *dbConfName) + if err != nil { + log.Fatal(err) + } + + folder := path.Join(*dbFolder, "migrations") + target := mostRecentVersionAvailable(folder) + runMigrations(conf, folder, target) +} + +// helper to identify the most recent possible version +// within a folder of migration scripts +func mostRecentVersionAvailable(dirpath string) int { + + dir, err := os.Open(dirpath) + if err != nil { + log.Fatal(err) + } + + names, err := dir.Readdirnames(0) + if err != nil { + log.Fatal(err) + } + + mostRecent := -1 + + for _, name := range names { + + if ext := path.Ext(name); ext != ".go" && ext != ".sql" { + continue + } + + v, e := numericComponent(name) + if e != nil { + continue + } + + if v > mostRecent { + mostRecent = v + } + } + + return mostRecent +} + +func init() { + upCmd.Run = upRun +}