Refactor goose pkg

pull/1/head
Vojtech Vitek (V-Teq) 2016-03-02 17:23:15 -05:00
parent 0eaa95867a
commit 2cccd9df36
14 changed files with 153 additions and 292 deletions

View File

@ -34,7 +34,7 @@ func init() {
// original .go migration, and execute it via `go run` along
// with a main() of our own creation.
//
func runGoMigration(conf *DBConf, path string, version int64, direction bool) error {
func runGoMigration(path string, version int64, direction bool) error {
// everything gets written to a temp dir, and zapped afterwards
d, e := ioutil.TempDir("", "goose")

View File

@ -1,22 +0,0 @@
test:
driver: postgres
open: user=liam dbname=tester sslmode=disable
development:
driver: postgres
open: user=liam dbname=tester sslmode=disable
production:
driver: postgres
open: user=liam dbname=tester sslmode=verify-full
customimport:
driver: customdriver
open: customdriver open
import: github.com/custom/driver
dialect: mysql
environment_variable_config:
driver: $DB_DRIVER
open: $DATABASE_URL

View File

@ -1,15 +0,0 @@
package main
import (
"database/sql"
"fmt"
)
func Up_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Up!")
}
func Down_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Down!")
}

139
dbconf.go
View File

@ -1,139 +0,0 @@
package goose
import (
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/kylelemons/go-gypsy/yaml"
"github.com/lib/pq"
)
// DBDriver encapsulates the info needed to work with
// a specific database driver
type DBDriver struct {
Name string
OpenStr string
Import string
Dialect SqlDialect
}
type DBConf struct {
MigrationsDir string
Env string
Driver DBDriver
PgSchema string
}
// extract configuration details from the given file
func NewDBConf(p, env string, pgschema string) (*DBConf, error) {
cfgFile := filepath.Join(p, "dbconf.yml")
f, err := yaml.ReadFile(cfgFile)
if err != nil {
return nil, err
}
drv, err := f.Get(fmt.Sprintf("%s.driver", env))
if err != nil {
return nil, err
}
drv = os.ExpandEnv(drv)
open, err := f.Get(fmt.Sprintf("%s.open", env))
if err != nil {
return nil, err
}
open = os.ExpandEnv(open)
// Automatically parse postgres urls
if drv == "postgres" {
// Assumption: If we can parse the URL, we should
if parsedURL, err := pq.ParseURL(open); err == nil && parsedURL != "" {
open = parsedURL
}
}
d := newDBDriver(drv, open)
// allow the configuration to override the Import for this driver
if imprt, err := f.Get(fmt.Sprintf("%s.import", env)); err == nil {
d.Import = imprt
}
// allow the configuration to override the Dialect for this driver
if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil {
d.Dialect = dialectByName(dialect)
}
if !d.IsValid() {
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
}
return &DBConf{
MigrationsDir: filepath.Join(p, "migrations"),
Env: env,
Driver: d,
PgSchema: pgschema,
}, nil
}
// Create a new DBDriver and populate driver specific
// fields for drivers that we know about.
// Further customization may be done in NewDBConf
func newDBDriver(name, open string) DBDriver {
d := DBDriver{
Name: name,
OpenStr: open,
}
switch name {
case "postgres":
d.Import = "github.com/lib/pq"
d.Dialect = &PostgresDialect{}
case "mymysql":
d.Import = "github.com/ziutek/mymysql/godrv"
d.Dialect = &MySqlDialect{}
case "mysql":
d.Import = "github.com/go-sql-driver/mysql"
d.Dialect = &MySqlDialect{}
case "sqlite3":
d.Import = "github.com/mattn/go-sqlite3"
d.Dialect = &Sqlite3Dialect{}
}
return d
}
// ensure we have enough info about this driver
func (drv *DBDriver) IsValid() bool {
return len(drv.Import) > 0 && drv.Dialect != nil
}
// OpenDBFromDBConf wraps database/sql.DB.Open() and configures
// the newly opened DB based on the given DBConf.
//
// Callers must Close() the returned DB.
func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) {
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
if err != nil {
return nil, err
}
// if a postgres schema has been specified, apply it
if conf.Driver.Name == "postgres" && conf.PgSchema != "" {
if _, err := db.Exec("SET search_path TO " + conf.PgSchema); err != nil {
return nil, err
}
}
return db, nil
}

View File

@ -1,70 +0,0 @@
package goose
import (
"os"
"reflect"
"testing"
)
func TestBasics(t *testing.T) {
dbconf, err := NewDBConf("../../db-sample", "test", "")
if err != nil {
t.Fatal(err)
}
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
want := []string{"../../db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
for i, s := range got {
if s != want[i] {
t.Errorf("Unexpected DBConf value. got %v, want %v", s, want[i])
}
}
}
func TestImportOverride(t *testing.T) {
dbconf, err := NewDBConf("../../db-sample", "customimport", "")
if err != nil {
t.Fatal(err)
}
got := dbconf.Driver.Import
want := "github.com/custom/driver"
if got != want {
t.Errorf("bad custom import. got %v want %v", got, want)
}
}
func TestDriverSetFromEnvironmentVariable(t *testing.T) {
databaseUrlEnvVariableKey := "DB_DRIVER"
databaseUrlEnvVariableVal := "sqlite3"
databaseOpenStringKey := "DATABASE_URL"
databaseOpenStringVal := "db.db"
os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal)
os.Setenv(databaseOpenStringKey, databaseOpenStringVal)
dbconf, err := NewDBConf("../../db-sample", "environment_variable_config", "")
if err != nil {
t.Fatal(err)
}
got := reflect.TypeOf(dbconf.Driver.Dialect)
want := reflect.TypeOf(&Sqlite3Dialect{})
if got != want {
t.Errorf("Not able to read the driver type from environment variable."+
"got %v want %v", got, want)
}
gotOpenString := dbconf.Driver.OpenStr
wantOpenString := databaseOpenStringVal
if gotOpenString != wantOpenString {
t.Errorf("Not able to read the open string from the environment."+
"got %v want %v", gotOpenString, wantOpenString)
}
}

23
down.go Normal file
View File

@ -0,0 +1,23 @@
package goose
import (
"database/sql"
)
func Down(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
if err != nil {
return err
}
previous, err := GetPreviousDBVersion(dir, current)
if err != nil {
return err
}
if err = RunMigrations(db, dir, previous); err != nil {
return err
}
return nil
}

View File

@ -1,4 +1,3 @@
-- +goose Up
CREATE TABLE post (
id int NOT NULL,

View File

@ -1,4 +1,3 @@
-- +goose Up
CREATE TABLE fancier_post (
id int NOT NULL,

View File

@ -0,0 +1,14 @@
package main
import (
"database/sql"
"fmt"
)
func Up_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Up!")
}
func Down_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Down!")
}

View File

@ -12,11 +12,6 @@ import (
"strings"
"text/template"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
)
var (
@ -48,25 +43,13 @@ func newMigration(v int64, src string) *Migration {
return &Migration{v, -1, -1, src}
}
func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) {
db, err := OpenDBFromDBConf(conf)
if err != nil {
return err
}
defer db.Close()
return RunMigrationsOnDb(conf, migrationsDir, target, db)
}
// Runs migration on a specific database instance.
func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) {
current, err := EnsureDBVersion(conf, db)
func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
current, err := EnsureDBVersion(db)
if err != nil {
return err
}
migrations, err := CollectMigrations(migrationsDir, current, target)
migrations, err := CollectMigrations(dir, current, target)
if err != nil {
return err
}
@ -80,16 +63,15 @@ func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql
direction := current < target
ms.Sort(direction)
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n",
conf.Env, current, target)
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", current, target)
for _, m := range ms {
switch filepath.Ext(m.Source) {
case ".go":
err = runGoMigration(conf, m.Source, m.Version, direction)
// case ".go":
// err = runGoMigration(m.Source, m.Version, direction)
case ".sql":
err = runSQLMigration(conf, db, m.Source, m.Version, direction)
err = runSQLMigration(db, m.Source, m.Version, direction)
}
if err != nil {
@ -192,12 +174,12 @@ func NumericComponent(name string) (int64, error) {
// retrieve the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
func EnsureDBVersion(db *sql.DB) (int64, error) {
rows, err := conf.Driver.Dialect.dbVersionQuery(db)
rows, err := dialectByName("postgres").dbVersionQuery(db)
if err != nil {
if err == ErrTableDoesNotExist {
return 0, createVersionTable(conf, db)
return 0, createVersionTable(db)
}
return 0, err
}
@ -242,13 +224,13 @@ func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
// Create the goose_db_version table
// and insert the initial 0 value into it
func createVersionTable(conf *DBConf, db *sql.DB) error {
func createVersionTable(db *sql.DB) error {
txn, err := db.Begin()
if err != nil {
return err
}
d := conf.Driver.Dialect
d := dialectByName("postgres")
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback()
@ -267,15 +249,8 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
// wrapper for EnsureDBVersion for callers that don't already have
// their own DB instance
func GetDBVersion(conf *DBConf) (version int64, err error) {
db, err := OpenDBFromDBConf(conf)
if err != nil {
return -1, err
}
defer db.Close()
version, err = EnsureDBVersion(conf, db)
func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db)
if err != nil {
return -1, err
}
@ -372,10 +347,10 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
// Update the version table for the given migration,
// and finalize the transaction.
func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
func FinalizeMigration(txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := conf.Driver.Dialect.insertVersionSql()
stmt := dialectByName("postgres").insertVersionSql()
if _, err := txn.Exec(stmt, v, direction); err != nil {
txn.Rollback()
return err

View File

@ -135,7 +135,7 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direction bool) error {
func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error {
txn, err := db.Begin()
if err != nil {
@ -160,7 +160,7 @@ func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direc
}
}
if err = FinalizeMigration(conf, txn, direction, v); err != nil {
if err = FinalizeMigration(txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(scriptFile), err)
}

27
redo.go Normal file
View File

@ -0,0 +1,27 @@
package goose
import (
"database/sql"
)
func Redo(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
if err != nil {
return err
}
previous, err := GetPreviousDBVersion(dir, current)
if err != nil {
return err
}
if err := RunMigrations(db, dir, previous); err != nil {
return err
}
if err := RunMigrations(db, dir, current); err != nil {
return err
}
return nil
}

53
status.go Normal file
View File

@ -0,0 +1,53 @@
package goose
import (
"database/sql"
"fmt"
"log"
"path/filepath"
"time"
)
func Status(db *sql.DB, dir string) error {
// collect all migrations
min := int64(0)
max := int64((1 << 63) - 1)
migrations, err := CollectMigrations(dir, min, max)
if err != nil {
return err
}
// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
return err
}
fmt.Println("goose: status")
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, m := range migrations {
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
}
return nil
}
func printMigrationStatus(db *sql.DB, version int64, script string) {
var row MigrationRecord
q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version)
e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
if e != nil && e != sql.ErrNoRows {
log.Fatal(e)
}
var appliedAt string
if row.IsApplied {
appliedAt = row.TStamp.Format(time.ANSIC)
} else {
appliedAt = "Pending"
}
fmt.Printf(" %-24s -- %v\n", appliedAt, script)
}

17
up.go Normal file
View File

@ -0,0 +1,17 @@
package goose
import (
"database/sql"
)
func Up(db *sql.DB, dir string) error {
target, err := GetMostRecentDBVersion(dir)
if err != nil {
return err
}
if err := RunMigrations(db, dir, target); err != nil {
return err
}
return nil
}