Refactor commands

pull/1/head
Vojtech Vitek (V-Teq) 2016-03-03 14:28:08 -05:00
parent 2cccd9df36
commit 59f7a561cb
14 changed files with 132 additions and 528 deletions

View File

@ -1,137 +0,0 @@
package goose
import (
"bytes"
"encoding/gob"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"text/template"
)
type templateData struct {
Version int64
Import string
Conf string // gob encoded DBConf
Direction bool
Func string
InsertStmt string
}
func init() {
gob.Register(PostgresDialect{})
gob.Register(MySqlDialect{})
gob.Register(Sqlite3Dialect{})
}
//
// Run a .go migration.
//
// In order to do this, we copy a modified version of the
// original .go migration, and execute it via `go run` along
// with a main() of our own creation.
//
func runGoMigration(path string, version int64, direction bool) error {
// everything gets written to a temp dir, and zapped afterwards
d, e := ioutil.TempDir("", "goose")
if e != nil {
log.Fatal(e)
}
defer os.RemoveAll(d)
directionStr := "Down"
if direction {
directionStr = "Up"
}
var bb bytes.Buffer
if err := gob.NewEncoder(&bb).Encode(conf); err != nil {
return err
}
// XXX: there must be a better way of making this byte array
// available to the generated code...
// but for now, print an array literal of the gob bytes
var sb bytes.Buffer
sb.WriteString("[]byte{ ")
for _, b := range bb.Bytes() {
sb.WriteString(fmt.Sprintf("0x%02x, ", b))
}
sb.WriteString("}")
td := &templateData{
Version: version,
Import: conf.Driver.Import,
Conf: sb.String(),
Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version),
InsertStmt: conf.Driver.Dialect.insertVersionSql(),
}
main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationDriverTemplate, td)
if e != nil {
log.Fatal(e)
}
outpath := filepath.Join(d, filepath.Base(path))
if _, e = copyFile(outpath, path); e != nil {
log.Fatal(e)
}
cmd := exec.Command("go", "run", main, outpath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if e = cmd.Run(); e != nil {
log.Fatal("`go run` failed: ", e)
}
return nil
}
//
// template for the main entry point to a go-based migration.
// this gets linked against the substituted versions of the user-supplied
// scripts in order to execute a migration via `go run`
//
var goMigrationDriverTemplate = template.Must(template.New("goose.go-driver").Parse(`
package main
import (
"log"
"bytes"
"encoding/gob"
_ "{{.Import}}"
"github.com/pressly/goose"
)
func main() {
var conf goose.DBConf
buf := bytes.NewBuffer({{ .Conf }})
if err := gob.NewDecoder(buf).Decode(&conf); err != nil {
log.Fatal("gob.Decode - ", err)
}
db, err := goose.OpenDBFromDBConf(&conf)
if err != nil {
log.Fatal("failed to open DB:", err)
}
defer db.Close()
txn, err := db.Begin()
if err != nil {
log.Fatal("db.Begin:", err)
}
{{ .Func }}(txn)
err = goose.FinalizeMigration(&conf, txn, {{ .Direction }}, {{ .Version }})
if err != nil {
log.Fatal("Commit() failed:", err)
}
}
`))

View File

@ -1,27 +0,0 @@
package main
import (
"flag"
)
// shamelessly snagged from the go tool
// each command gets its own set of args,
// defines its own entry point, and provides its own help
type Command struct {
Run func(cmd *Command, args ...string)
Flag flag.FlagSet
Name string
Usage string
Summary string
Help string
}
func (c *Command) Exec(args []string) {
c.Flag.Usage = func() {
// helpFunc(c, c.Name)
}
c.Flag.Parse(args)
c.Run(c, c.Flag.Args()...)
}

View File

@ -1,51 +0,0 @@
package main
import (
"github.com/pressly/goose"
"fmt"
"log"
"os"
"path/filepath"
"time"
)
var createCmd = &Command{
Name: "create",
Usage: "",
Summary: "Create the scaffolding for a new migration",
Help: `create extended help here...`,
Run: createRun,
}
func createRun(cmd *Command, args ...string) {
if len(args) < 1 {
log.Fatal("goose create: migration name required")
}
migrationType := "go" // default to Go migrations
if len(args) >= 2 {
migrationType = args[1]
}
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
if err = os.MkdirAll(conf.MigrationsDir, 0777); err != nil {
log.Fatal(err)
}
n, err := goose.CreateMigration(args[0], migrationType, conf.MigrationsDir, time.Now())
if err != nil {
log.Fatal(err)
}
a, e := filepath.Abs(n)
if e != nil {
log.Fatal(e)
}
fmt.Println("goose: created", a)
}

View File

@ -1,29 +0,0 @@
package main
import (
"github.com/pressly/goose"
"fmt"
"log"
)
var dbVersionCmd = &Command{
Name: "dbversion",
Usage: "",
Summary: "Print the current version of the database",
Help: `dbversion extended help here...`,
Run: dbVersionRun,
}
func dbVersionRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
fmt.Printf("goose: dbversion %v\n", current)
}

View File

@ -1,36 +0,0 @@
package main
import (
"github.com/pressly/goose"
"log"
)
var downCmd = &Command{
Name: "down",
Usage: "",
Summary: "Roll back the version by 1",
Help: `down extended help here...`,
Run: downRun,
}
func downRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
previous, err := goose.GetPreviousDBVersion(conf.MigrationsDir, current)
if err != nil {
log.Fatal(err)
}
if err = goose.RunMigrations(conf, conf.MigrationsDir, previous); err != nil {
log.Fatal(err)
}
}

View File

@ -1,39 +0,0 @@
package main
import (
"github.com/pressly/goose"
"log"
)
var redoCmd = &Command{
Name: "redo",
Usage: "",
Summary: "Re-run the latest migration",
Help: `redo extended help here...`,
Run: redoRun,
}
func redoRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
previous, err := goose.GetPreviousDBVersion(conf.MigrationsDir, current)
if err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, previous); err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, current); err != nil {
log.Fatal(err)
}
}

View File

@ -1,77 +0,0 @@
package main
import (
"github.com/pressly/goose"
"database/sql"
"fmt"
"log"
"path/filepath"
"time"
)
var statusCmd = &Command{
Name: "status",
Usage: "",
Summary: "dump the migration status for the current DB",
Help: `status extended help here...`,
Run: statusRun,
}
type StatusData struct {
Source string
Status string
}
func statusRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
// collect all migrations
min := int64(0)
max := int64((1 << 63) - 1)
migrations, e := goose.CollectMigrations(conf.MigrationsDir, min, max)
if e != nil {
log.Fatal(e)
}
db, e := goose.OpenDBFromDBConf(conf)
if e != nil {
log.Fatal("couldn't open DB:", e)
}
defer db.Close()
// must ensure that the version table exists if we're running on a pristine DB
if _, e := goose.EnsureDBVersion(conf, db); e != nil {
log.Fatal(e)
}
fmt.Printf("goose: status for environment '%v'\n", conf.Env)
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, m := range migrations {
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
}
}
func printMigrationStatus(db *sql.DB, version int64, script string) {
var row goose.MigrationRecord
q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version)
e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
if e != nil && e != sql.ErrNoRows {
log.Fatal(e)
}
var appliedAt string
if row.IsApplied {
appliedAt = row.TStamp.Format(time.ANSIC)
} else {
appliedAt = "Pending"
}
fmt.Printf(" %-24s -- %v\n", appliedAt, script)
}

View File

@ -1,32 +0,0 @@
package main
import (
"log"
"github.com/pressly/goose"
)
var upCmd = &Command{
Name: "up",
Usage: "",
Summary: "Migrate the DB to the most recent version available",
Help: `up extended help here...`,
Run: upRun,
}
func upRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, target); err != nil {
log.Fatal(err)
}
}

View File

@ -1,79 +1,90 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strings"
"text/template"
"github.com/pressly/goose/lib/goose"
"github.com/pressly/goose"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
)
// global options. available to any subcommands.
var flagPath = flag.String("path", "db", "folder containing db info")
var flagEnv = flag.String("env", "development", "which DB environment to use")
var flagPgSchema = flag.String("pgschema", "", "which postgres-schema to migrate (default = none)")
// helper to create a DBConf from the given flags
func dbConfFromFlags() (dbconf *goose.DBConf, err error) {
return goose.NewDBConf(*flagPath, *flagEnv, *flagPgSchema)
}
var commands = []*Command{
upCmd,
downCmd,
redoCmd,
statusCmd,
createCmd,
dbVersionCmd,
}
var (
flags = flag.NewFlagSet("goose", flag.ExitOnError)
dir = flags.String("dir", ".", "directory with migration files")
)
func main() {
flags.Usage = usage
flags.Parse(os.Args[1:])
flag.Usage = usage
flag.Parse()
args := flag.Args()
if len(args) == 0 || args[0] == "-h" {
flag.Usage()
args := flags.Args()
if len(args) != 3 {
flags.Usage()
return
}
var cmd *Command
name := args[0]
for _, c := range commands {
if strings.HasPrefix(c.Name, name) {
cmd = c
break
if args[0] == "-h" || args[0] == "--help" {
flags.Usage()
return
}
driver, dbstring, command := args[0], args[1], args[2]
switch driver {
case "postgres", "mysql", "sqlite3":
if err := goose.SetDialect(driver); err != nil {
log.Fatal(err)
}
default:
log.Fatalf("%q driver not supported\n", driver)
}
if cmd == nil {
fmt.Printf("error: unknown command %q\n", name)
flag.Usage()
os.Exit(1)
switch dbstring {
case "":
log.Fatalf("-dbstring=%q not supported\n", dbstring)
default:
}
cmd.Exec(args[1:])
db, err := sql.Open(driver, dbstring)
if err != nil {
log.Fatalf("-dbstring=%q: %v\n", dbstring, err)
}
if err := goose.Run(command, db, *dir); err != nil {
log.Fatalf("goose run: %v", err)
}
}
func usage() {
fmt.Print(usagePrefix)
flag.PrintDefaults()
usageTmpl.Execute(os.Stdout, commands)
flags.PrintDefaults()
fmt.Print(usageCommands)
}
var usagePrefix = `
goose is a database migration management system for Go projects.
var (
usagePrefix = `Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND
Usage:
goose [options] <subcommand> [subcommand options]
Examples:
goose postgres "user=postgres dbname=postgres sslmode=disable" up
goose mysql "user:password@/dbname" down
goose sqlite3 ./foo.db status
Options:
`
var usageTmpl = template.Must(template.New("usage").Parse(
`
Commands:{{range .}}
{{.Name | printf "%-10s"}} {{.Summary}}{{end}}
`))
usageCommands = `
Commands:
up Migrate the DB to the most recent version available
down Roll back the version by 1
redo Re-run the latest migration
status Dump the migration status for the current DB
dbversion Print the current version of the database
`
)

View File

@ -2,6 +2,8 @@ package goose
import (
"database/sql"
"fmt"
"github.com/mattn/go-sqlite3"
)
@ -13,15 +15,22 @@ type SqlDialect interface {
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
}
// drivers that we don't know about can ask for a dialect by name
func dialectByName(d string) SqlDialect {
var dialect SqlDialect = &PostgresDialect{}
func GetDialect() SqlDialect {
return dialect
}
func SetDialect(d string) error {
switch d {
case "postgres":
return &PostgresDialect{}
dialect = &PostgresDialect{}
case "mysql":
return &MySqlDialect{}
dialect = &MySqlDialect{}
case "sqlite3":
return &Sqlite3Dialect{}
dialect = &Sqlite3Dialect{}
default:
return fmt.Errorf("%q: unknown dialect", d)
}
return nil

View File

@ -1,14 +0,0 @@
package main
import (
"database/sql"
"fmt"
)
func Up_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Up!")
}
func Down_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Down!")
}

34
goose.go Normal file
View File

@ -0,0 +1,34 @@
package goose
import (
"database/sql"
"fmt"
)
func Run(command string, db *sql.DB, dir string) error {
switch command {
case "up":
if err := Up(db, dir); err != nil {
return err
}
case "down":
if err := Down(db, dir); err != nil {
return err
}
case "redo":
if err := Redo(db, dir); err != nil {
return err
}
case "status":
if err := Status(db, dir); err != nil {
return err
}
case "version":
if err := Version(db, dir); err != nil {
return err
}
default:
return fmt.Errorf("%q: no such command", command)
}
return nil
}

View File

@ -63,17 +63,17 @@ func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
direction := current < target
ms.Sort(direction)
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", current, target)
fmt.Printf("goose: migrating db, current version: %d, target: %d\n", current, target)
for _, m := range ms {
switch filepath.Ext(m.Source) {
// case ".go":
// err = runGoMigration(m.Source, m.Version, direction)
case ".sql":
err = runSQLMigration(db, m.Source, m.Version, direction)
default:
continue
}
err = runSQLMigration(db, m.Source, m.Version, direction)
if err != nil {
return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err))
}
@ -176,7 +176,7 @@ func NumericComponent(name string) (int64, error) {
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
rows, err := dialectByName("postgres").dbVersionQuery(db)
rows, err := GetDialect().dbVersionQuery(db)
if err != nil {
if err == ErrTableDoesNotExist {
return 0, createVersionTable(db)
@ -230,7 +230,7 @@ func createVersionTable(db *sql.DB) error {
return err
}
d := dialectByName("postgres")
d := GetDialect()
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback()
@ -332,13 +332,7 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType)
fpath := filepath.Join(dir, filename)
var tmpl *template.Template
if migrationType == "sql" {
tmpl = sqlMigrationTemplate
} else {
tmpl = goMigrationTemplate
}
tmpl := sqlMigrationTemplate
path, err = writeTemplateToFile(fpath, tmpl, timestamp)
@ -350,7 +344,7 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
func FinalizeMigration(txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := dialectByName("postgres").insertVersionSql()
stmt := GetDialect().insertVersionSql()
if _, err := txn.Exec(stmt, v, direction); err != nil {
txn.Rollback()
return err
@ -359,24 +353,6 @@ func FinalizeMigration(txn *sql.Tx, direction bool, v int64) error {
return txn.Commit()
}
var goMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
package main
import (
"database/sql"
)
// Up is executed when this migration is applied
func Up_{{ . }}(txn *sql.Tx) {
}
// Down is executed when this migration is rolled back
func Down_{{ . }}(txn *sql.Tx) {
}
`))
var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`
-- +goose Up
-- SQL in section 'Up' is executed when this migration is applied

16
version.go Normal file
View File

@ -0,0 +1,16 @@
package goose
import (
"database/sql"
"fmt"
)
func Version(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
if err != nil {
return err
}
fmt.Printf("goose: dbversion %v\n", current)
return nil
}