From 2cccd9df36bf8db0270e1767ae01f07828fa82e5 Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (V-Teq)" Date: Wed, 2 Mar 2016 17:23:15 -0500 Subject: [PATCH] Refactor goose pkg --- migration_go.go => _migration_go.go | 2 +- db-sample/dbconf.yml | 22 --- .../migrations/20130106222315_and_again.go | 15 -- dbconf.go | 139 ------------------ dbconf_test.go | 70 --------- down.go | 23 +++ .../migrations/00001_create_post_table.sql | 1 - .../migrations/00002_next.sql | 1 - example/migrations/00003_go_migration.go | 14 ++ migrate.go | 57 ++----- migration_sql.go | 4 +- redo.go | 27 ++++ status.go | 53 +++++++ up.go | 17 +++ 14 files changed, 153 insertions(+), 292 deletions(-) rename migration_go.go => _migration_go.go (97%) delete mode 100644 db-sample/dbconf.yml delete mode 100644 db-sample/migrations/20130106222315_and_again.go delete mode 100644 dbconf.go delete mode 100644 dbconf_test.go create mode 100644 down.go rename db-sample/migrations/001_basics.sql => example/migrations/00001_create_post_table.sql (99%) rename db-sample/migrations/002_next.sql => example/migrations/00002_next.sql (99%) create mode 100644 example/migrations/00003_go_migration.go create mode 100644 redo.go create mode 100644 status.go create mode 100644 up.go diff --git a/migration_go.go b/_migration_go.go similarity index 97% rename from migration_go.go rename to _migration_go.go index ed13c2c..032cfd2 100644 --- a/migration_go.go +++ b/_migration_go.go @@ -34,7 +34,7 @@ func init() { // original .go migration, and execute it via `go run` along // with a main() of our own creation. // -func runGoMigration(conf *DBConf, path string, version int64, direction bool) error { +func runGoMigration(path string, version int64, direction bool) error { // everything gets written to a temp dir, and zapped afterwards d, e := ioutil.TempDir("", "goose") diff --git a/db-sample/dbconf.yml b/db-sample/dbconf.yml deleted file mode 100644 index b820527..0000000 --- a/db-sample/dbconf.yml +++ /dev/null @@ -1,22 +0,0 @@ - -test: - driver: postgres - open: user=liam dbname=tester sslmode=disable - -development: - driver: postgres - open: user=liam dbname=tester sslmode=disable - -production: - driver: postgres - open: user=liam dbname=tester sslmode=verify-full - -customimport: - driver: customdriver - open: customdriver open - import: github.com/custom/driver - dialect: mysql - -environment_variable_config: - driver: $DB_DRIVER - open: $DATABASE_URL diff --git a/db-sample/migrations/20130106222315_and_again.go b/db-sample/migrations/20130106222315_and_again.go deleted file mode 100644 index 1aac8ba..0000000 --- a/db-sample/migrations/20130106222315_and_again.go +++ /dev/null @@ -1,15 +0,0 @@ - -package main - -import ( - "database/sql" - "fmt" -) - -func Up_20130106222315(txn *sql.Tx) { - fmt.Println("Hello from migration 20130106222315 Up!") -} - -func Down_20130106222315(txn *sql.Tx) { - fmt.Println("Hello from migration 20130106222315 Down!") -} diff --git a/dbconf.go b/dbconf.go deleted file mode 100644 index 691a847..0000000 --- a/dbconf.go +++ /dev/null @@ -1,139 +0,0 @@ -package goose - -import ( - "database/sql" - "errors" - "fmt" - "os" - "path/filepath" - - "github.com/kylelemons/go-gypsy/yaml" - "github.com/lib/pq" -) - -// DBDriver encapsulates the info needed to work with -// a specific database driver -type DBDriver struct { - Name string - OpenStr string - Import string - Dialect SqlDialect -} - -type DBConf struct { - MigrationsDir string - Env string - Driver DBDriver - PgSchema string -} - -// extract configuration details from the given file -func NewDBConf(p, env string, pgschema string) (*DBConf, error) { - - cfgFile := filepath.Join(p, "dbconf.yml") - - f, err := yaml.ReadFile(cfgFile) - if err != nil { - return nil, err - } - - drv, err := f.Get(fmt.Sprintf("%s.driver", env)) - if err != nil { - return nil, err - } - drv = os.ExpandEnv(drv) - - open, err := f.Get(fmt.Sprintf("%s.open", env)) - if err != nil { - return nil, err - } - open = os.ExpandEnv(open) - - // Automatically parse postgres urls - if drv == "postgres" { - - // Assumption: If we can parse the URL, we should - if parsedURL, err := pq.ParseURL(open); err == nil && parsedURL != "" { - open = parsedURL - } - } - - d := newDBDriver(drv, open) - - // allow the configuration to override the Import for this driver - if imprt, err := f.Get(fmt.Sprintf("%s.import", env)); err == nil { - d.Import = imprt - } - - // allow the configuration to override the Dialect for this driver - if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil { - d.Dialect = dialectByName(dialect) - } - - if !d.IsValid() { - return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d)) - } - - return &DBConf{ - MigrationsDir: filepath.Join(p, "migrations"), - Env: env, - Driver: d, - PgSchema: pgschema, - }, nil -} - -// Create a new DBDriver and populate driver specific -// fields for drivers that we know about. -// Further customization may be done in NewDBConf -func newDBDriver(name, open string) DBDriver { - - d := DBDriver{ - Name: name, - OpenStr: open, - } - - switch name { - case "postgres": - d.Import = "github.com/lib/pq" - d.Dialect = &PostgresDialect{} - - case "mymysql": - d.Import = "github.com/ziutek/mymysql/godrv" - d.Dialect = &MySqlDialect{} - - case "mysql": - d.Import = "github.com/go-sql-driver/mysql" - d.Dialect = &MySqlDialect{} - - case "sqlite3": - d.Import = "github.com/mattn/go-sqlite3" - d.Dialect = &Sqlite3Dialect{} - } - - return d -} - -// ensure we have enough info about this driver -func (drv *DBDriver) IsValid() bool { - return len(drv.Import) > 0 && drv.Dialect != nil -} - -// OpenDBFromDBConf wraps database/sql.DB.Open() and configures -// the newly opened DB based on the given DBConf. -// -// Callers must Close() the returned DB. -func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) { - db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr) - if err != nil { - return nil, err - } - - // if a postgres schema has been specified, apply it - if conf.Driver.Name == "postgres" && conf.PgSchema != "" { - if _, err := db.Exec("SET search_path TO " + conf.PgSchema); err != nil { - return nil, err - } - } - - return db, nil -} diff --git a/dbconf_test.go b/dbconf_test.go deleted file mode 100644 index 37828ae..0000000 --- a/dbconf_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package goose - -import ( - "os" - "reflect" - "testing" -) - -func TestBasics(t *testing.T) { - - dbconf, err := NewDBConf("../../db-sample", "test", "") - if err != nil { - t.Fatal(err) - } - - got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr} - want := []string{"../../db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"} - - for i, s := range got { - if s != want[i] { - t.Errorf("Unexpected DBConf value. got %v, want %v", s, want[i]) - } - } -} - -func TestImportOverride(t *testing.T) { - - dbconf, err := NewDBConf("../../db-sample", "customimport", "") - if err != nil { - t.Fatal(err) - } - - got := dbconf.Driver.Import - want := "github.com/custom/driver" - if got != want { - t.Errorf("bad custom import. got %v want %v", got, want) - } -} - -func TestDriverSetFromEnvironmentVariable(t *testing.T) { - - databaseUrlEnvVariableKey := "DB_DRIVER" - databaseUrlEnvVariableVal := "sqlite3" - databaseOpenStringKey := "DATABASE_URL" - databaseOpenStringVal := "db.db" - - os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal) - os.Setenv(databaseOpenStringKey, databaseOpenStringVal) - - dbconf, err := NewDBConf("../../db-sample", "environment_variable_config", "") - if err != nil { - t.Fatal(err) - } - - got := reflect.TypeOf(dbconf.Driver.Dialect) - want := reflect.TypeOf(&Sqlite3Dialect{}) - - if got != want { - t.Errorf("Not able to read the driver type from environment variable."+ - "got %v want %v", got, want) - } - - gotOpenString := dbconf.Driver.OpenStr - wantOpenString := databaseOpenStringVal - - if gotOpenString != wantOpenString { - t.Errorf("Not able to read the open string from the environment."+ - "got %v want %v", gotOpenString, wantOpenString) - } -} diff --git a/down.go b/down.go new file mode 100644 index 0000000..43a9f9d --- /dev/null +++ b/down.go @@ -0,0 +1,23 @@ +package goose + +import ( + "database/sql" +) + +func Down(db *sql.DB, dir string) error { + current, err := GetDBVersion(db) + if err != nil { + return err + } + + previous, err := GetPreviousDBVersion(dir, current) + if err != nil { + return err + } + + if err = RunMigrations(db, dir, previous); err != nil { + return err + } + + return nil +} diff --git a/db-sample/migrations/001_basics.sql b/example/migrations/00001_create_post_table.sql similarity index 99% rename from db-sample/migrations/001_basics.sql rename to example/migrations/00001_create_post_table.sql index 2a5bb57..b962691 100644 --- a/db-sample/migrations/001_basics.sql +++ b/example/migrations/00001_create_post_table.sql @@ -1,4 +1,3 @@ - -- +goose Up CREATE TABLE post ( id int NOT NULL, diff --git a/db-sample/migrations/002_next.sql b/example/migrations/00002_next.sql similarity index 99% rename from db-sample/migrations/002_next.sql rename to example/migrations/00002_next.sql index 9f9be33..4cea847 100644 --- a/db-sample/migrations/002_next.sql +++ b/example/migrations/00002_next.sql @@ -1,4 +1,3 @@ - -- +goose Up CREATE TABLE fancier_post ( id int NOT NULL, diff --git a/example/migrations/00003_go_migration.go b/example/migrations/00003_go_migration.go new file mode 100644 index 0000000..21c2089 --- /dev/null +++ b/example/migrations/00003_go_migration.go @@ -0,0 +1,14 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func Up_20130106222315(txn *sql.Tx) { + fmt.Println("Hello from migration 20130106222315 Up!") +} + +func Down_20130106222315(txn *sql.Tx) { + fmt.Println("Hello from migration 20130106222315 Down!") +} diff --git a/migrate.go b/migrate.go index df1ce2d..20af96c 100644 --- a/migrate.go +++ b/migrate.go @@ -12,11 +12,6 @@ import ( "strings" "text/template" "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - _ "github.com/ziutek/mymysql/godrv" ) var ( @@ -48,25 +43,13 @@ func newMigration(v int64, src string) *Migration { return &Migration{v, -1, -1, src} } -func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) { - - db, err := OpenDBFromDBConf(conf) - if err != nil { - return err - } - defer db.Close() - - return RunMigrationsOnDb(conf, migrationsDir, target, db) -} - -// Runs migration on a specific database instance. -func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) { - current, err := EnsureDBVersion(conf, db) +func RunMigrations(db *sql.DB, dir string, target int64) (err error) { + current, err := EnsureDBVersion(db) if err != nil { return err } - migrations, err := CollectMigrations(migrationsDir, current, target) + migrations, err := CollectMigrations(dir, current, target) if err != nil { return err } @@ -80,16 +63,15 @@ func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql direction := current < target ms.Sort(direction) - fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", - conf.Env, current, target) + fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", current, target) for _, m := range ms { switch filepath.Ext(m.Source) { - case ".go": - err = runGoMigration(conf, m.Source, m.Version, direction) + // case ".go": + // err = runGoMigration(m.Source, m.Version, direction) case ".sql": - err = runSQLMigration(conf, db, m.Source, m.Version, direction) + err = runSQLMigration(db, m.Source, m.Version, direction) } if err != nil { @@ -192,12 +174,12 @@ func NumericComponent(name string) (int64, error) { // retrieve the current version for this DB. // Create and initialize the DB version table if it doesn't exist. -func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) { +func EnsureDBVersion(db *sql.DB) (int64, error) { - rows, err := conf.Driver.Dialect.dbVersionQuery(db) + rows, err := dialectByName("postgres").dbVersionQuery(db) if err != nil { if err == ErrTableDoesNotExist { - return 0, createVersionTable(conf, db) + return 0, createVersionTable(db) } return 0, err } @@ -242,13 +224,13 @@ func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) { // Create the goose_db_version table // and insert the initial 0 value into it -func createVersionTable(conf *DBConf, db *sql.DB) error { +func createVersionTable(db *sql.DB) error { txn, err := db.Begin() if err != nil { return err } - d := conf.Driver.Dialect + d := dialectByName("postgres") if _, err := txn.Exec(d.createVersionTableSql()); err != nil { txn.Rollback() @@ -267,15 +249,8 @@ func createVersionTable(conf *DBConf, db *sql.DB) error { // wrapper for EnsureDBVersion for callers that don't already have // their own DB instance -func GetDBVersion(conf *DBConf) (version int64, err error) { - - db, err := OpenDBFromDBConf(conf) - if err != nil { - return -1, err - } - defer db.Close() - - version, err = EnsureDBVersion(conf, db) +func GetDBVersion(db *sql.DB) (int64, error) { + version, err := EnsureDBVersion(db) if err != nil { return -1, err } @@ -372,10 +347,10 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string, // Update the version table for the given migration, // and finalize the transaction. -func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error { +func FinalizeMigration(txn *sql.Tx, direction bool, v int64) error { // XXX: drop goose_db_version table on some minimum version number? - stmt := conf.Driver.Dialect.insertVersionSql() + stmt := dialectByName("postgres").insertVersionSql() if _, err := txn.Exec(stmt, v, direction); err != nil { txn.Rollback() return err diff --git a/migration_sql.go b/migration_sql.go index 7dc0495..8c55ac4 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -135,7 +135,7 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { // // All statements following an Up or Down directive are grouped together // until another direction directive is found. -func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direction bool) error { +func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error { txn, err := db.Begin() if err != nil { @@ -160,7 +160,7 @@ func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direc } } - if err = FinalizeMigration(conf, txn, direction, v); err != nil { + if err = FinalizeMigration(txn, direction, v); err != nil { log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(scriptFile), err) } diff --git a/redo.go b/redo.go new file mode 100644 index 0000000..4b34117 --- /dev/null +++ b/redo.go @@ -0,0 +1,27 @@ +package goose + +import ( + "database/sql" +) + +func Redo(db *sql.DB, dir string) error { + current, err := GetDBVersion(db) + if err != nil { + return err + } + + previous, err := GetPreviousDBVersion(dir, current) + if err != nil { + return err + } + + if err := RunMigrations(db, dir, previous); err != nil { + return err + } + + if err := RunMigrations(db, dir, current); err != nil { + return err + } + + return nil +} diff --git a/status.go b/status.go new file mode 100644 index 0000000..a21ebaf --- /dev/null +++ b/status.go @@ -0,0 +1,53 @@ +package goose + +import ( + "database/sql" + "fmt" + "log" + "path/filepath" + "time" +) + +func Status(db *sql.DB, dir string) error { + // collect all migrations + min := int64(0) + max := int64((1 << 63) - 1) + migrations, err := CollectMigrations(dir, min, max) + if err != nil { + return err + } + + // must ensure that the version table exists if we're running on a pristine DB + if _, err := EnsureDBVersion(db); err != nil { + return err + } + + fmt.Println("goose: status") + fmt.Println(" Applied At Migration") + fmt.Println(" =======================================") + for _, m := range migrations { + printMigrationStatus(db, m.Version, filepath.Base(m.Source)) + } + + return nil +} + +func printMigrationStatus(db *sql.DB, version int64, script string) { + var row MigrationRecord + q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version) + e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied) + + if e != nil && e != sql.ErrNoRows { + log.Fatal(e) + } + + var appliedAt string + + if row.IsApplied { + appliedAt = row.TStamp.Format(time.ANSIC) + } else { + appliedAt = "Pending" + } + + fmt.Printf(" %-24s -- %v\n", appliedAt, script) +} diff --git a/up.go b/up.go new file mode 100644 index 0000000..79c9807 --- /dev/null +++ b/up.go @@ -0,0 +1,17 @@ +package goose + +import ( + "database/sql" +) + +func Up(db *sql.DB, dir string) error { + target, err := GetMostRecentDBVersion(dir) + if err != nil { + return err + } + + if err := RunMigrations(db, dir, target); err != nil { + return err + } + return nil +}