mirror of https://github.com/pressly/goose.git
Refactor goose pkg
parent
0eaa95867a
commit
2cccd9df36
|
@ -34,7 +34,7 @@ func init() {
|
|||
// original .go migration, and execute it via `go run` along
|
||||
// with a main() of our own creation.
|
||||
//
|
||||
func runGoMigration(conf *DBConf, path string, version int64, direction bool) error {
|
||||
func runGoMigration(path string, version int64, direction bool) error {
|
||||
|
||||
// everything gets written to a temp dir, and zapped afterwards
|
||||
d, e := ioutil.TempDir("", "goose")
|
|
@ -1,22 +0,0 @@
|
|||
|
||||
test:
|
||||
driver: postgres
|
||||
open: user=liam dbname=tester sslmode=disable
|
||||
|
||||
development:
|
||||
driver: postgres
|
||||
open: user=liam dbname=tester sslmode=disable
|
||||
|
||||
production:
|
||||
driver: postgres
|
||||
open: user=liam dbname=tester sslmode=verify-full
|
||||
|
||||
customimport:
|
||||
driver: customdriver
|
||||
open: customdriver open
|
||||
import: github.com/custom/driver
|
||||
dialect: mysql
|
||||
|
||||
environment_variable_config:
|
||||
driver: $DB_DRIVER
|
||||
open: $DATABASE_URL
|
|
@ -1,15 +0,0 @@
|
|||
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func Up_20130106222315(txn *sql.Tx) {
|
||||
fmt.Println("Hello from migration 20130106222315 Up!")
|
||||
}
|
||||
|
||||
func Down_20130106222315(txn *sql.Tx) {
|
||||
fmt.Println("Hello from migration 20130106222315 Down!")
|
||||
}
|
139
dbconf.go
139
dbconf.go
|
@ -1,139 +0,0 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kylelemons/go-gypsy/yaml"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// DBDriver encapsulates the info needed to work with
|
||||
// a specific database driver
|
||||
type DBDriver struct {
|
||||
Name string
|
||||
OpenStr string
|
||||
Import string
|
||||
Dialect SqlDialect
|
||||
}
|
||||
|
||||
type DBConf struct {
|
||||
MigrationsDir string
|
||||
Env string
|
||||
Driver DBDriver
|
||||
PgSchema string
|
||||
}
|
||||
|
||||
// extract configuration details from the given file
|
||||
func NewDBConf(p, env string, pgschema string) (*DBConf, error) {
|
||||
|
||||
cfgFile := filepath.Join(p, "dbconf.yml")
|
||||
|
||||
f, err := yaml.ReadFile(cfgFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
drv, err := f.Get(fmt.Sprintf("%s.driver", env))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
drv = os.ExpandEnv(drv)
|
||||
|
||||
open, err := f.Get(fmt.Sprintf("%s.open", env))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
open = os.ExpandEnv(open)
|
||||
|
||||
// Automatically parse postgres urls
|
||||
if drv == "postgres" {
|
||||
|
||||
// Assumption: If we can parse the URL, we should
|
||||
if parsedURL, err := pq.ParseURL(open); err == nil && parsedURL != "" {
|
||||
open = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
d := newDBDriver(drv, open)
|
||||
|
||||
// allow the configuration to override the Import for this driver
|
||||
if imprt, err := f.Get(fmt.Sprintf("%s.import", env)); err == nil {
|
||||
d.Import = imprt
|
||||
}
|
||||
|
||||
// allow the configuration to override the Dialect for this driver
|
||||
if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil {
|
||||
d.Dialect = dialectByName(dialect)
|
||||
}
|
||||
|
||||
if !d.IsValid() {
|
||||
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
|
||||
}
|
||||
|
||||
return &DBConf{
|
||||
MigrationsDir: filepath.Join(p, "migrations"),
|
||||
Env: env,
|
||||
Driver: d,
|
||||
PgSchema: pgschema,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create a new DBDriver and populate driver specific
|
||||
// fields for drivers that we know about.
|
||||
// Further customization may be done in NewDBConf
|
||||
func newDBDriver(name, open string) DBDriver {
|
||||
|
||||
d := DBDriver{
|
||||
Name: name,
|
||||
OpenStr: open,
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "postgres":
|
||||
d.Import = "github.com/lib/pq"
|
||||
d.Dialect = &PostgresDialect{}
|
||||
|
||||
case "mymysql":
|
||||
d.Import = "github.com/ziutek/mymysql/godrv"
|
||||
d.Dialect = &MySqlDialect{}
|
||||
|
||||
case "mysql":
|
||||
d.Import = "github.com/go-sql-driver/mysql"
|
||||
d.Dialect = &MySqlDialect{}
|
||||
|
||||
case "sqlite3":
|
||||
d.Import = "github.com/mattn/go-sqlite3"
|
||||
d.Dialect = &Sqlite3Dialect{}
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// ensure we have enough info about this driver
|
||||
func (drv *DBDriver) IsValid() bool {
|
||||
return len(drv.Import) > 0 && drv.Dialect != nil
|
||||
}
|
||||
|
||||
// OpenDBFromDBConf wraps database/sql.DB.Open() and configures
|
||||
// the newly opened DB based on the given DBConf.
|
||||
//
|
||||
// Callers must Close() the returned DB.
|
||||
func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) {
|
||||
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if a postgres schema has been specified, apply it
|
||||
if conf.Driver.Name == "postgres" && conf.PgSchema != "" {
|
||||
if _, err := db.Exec("SET search_path TO " + conf.PgSchema); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBasics(t *testing.T) {
|
||||
|
||||
dbconf, err := NewDBConf("../../db-sample", "test", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
|
||||
want := []string{"../../db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
|
||||
|
||||
for i, s := range got {
|
||||
if s != want[i] {
|
||||
t.Errorf("Unexpected DBConf value. got %v, want %v", s, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportOverride(t *testing.T) {
|
||||
|
||||
dbconf, err := NewDBConf("../../db-sample", "customimport", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := dbconf.Driver.Import
|
||||
want := "github.com/custom/driver"
|
||||
if got != want {
|
||||
t.Errorf("bad custom import. got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriverSetFromEnvironmentVariable(t *testing.T) {
|
||||
|
||||
databaseUrlEnvVariableKey := "DB_DRIVER"
|
||||
databaseUrlEnvVariableVal := "sqlite3"
|
||||
databaseOpenStringKey := "DATABASE_URL"
|
||||
databaseOpenStringVal := "db.db"
|
||||
|
||||
os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal)
|
||||
os.Setenv(databaseOpenStringKey, databaseOpenStringVal)
|
||||
|
||||
dbconf, err := NewDBConf("../../db-sample", "environment_variable_config", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := reflect.TypeOf(dbconf.Driver.Dialect)
|
||||
want := reflect.TypeOf(&Sqlite3Dialect{})
|
||||
|
||||
if got != want {
|
||||
t.Errorf("Not able to read the driver type from environment variable."+
|
||||
"got %v want %v", got, want)
|
||||
}
|
||||
|
||||
gotOpenString := dbconf.Driver.OpenStr
|
||||
wantOpenString := databaseOpenStringVal
|
||||
|
||||
if gotOpenString != wantOpenString {
|
||||
t.Errorf("Not able to read the open string from the environment."+
|
||||
"got %v want %v", gotOpenString, wantOpenString)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func Down(db *sql.DB, dir string) error {
|
||||
current, err := GetDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
previous, err := GetPreviousDBVersion(dir, current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = RunMigrations(db, dir, previous); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
-- +goose Up
|
||||
CREATE TABLE post (
|
||||
id int NOT NULL,
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
-- +goose Up
|
||||
CREATE TABLE fancier_post (
|
||||
id int NOT NULL,
|
|
@ -0,0 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func Up_20130106222315(txn *sql.Tx) {
|
||||
fmt.Println("Hello from migration 20130106222315 Up!")
|
||||
}
|
||||
|
||||
func Down_20130106222315(txn *sql.Tx) {
|
||||
fmt.Println("Hello from migration 20130106222315 Down!")
|
||||
}
|
57
migrate.go
57
migrate.go
|
@ -12,11 +12,6 @@ import (
|
|||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/ziutek/mymysql/godrv"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -48,25 +43,13 @@ func newMigration(v int64, src string) *Migration {
|
|||
return &Migration{v, -1, -1, src}
|
||||
}
|
||||
|
||||
func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) {
|
||||
|
||||
db, err := OpenDBFromDBConf(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
return RunMigrationsOnDb(conf, migrationsDir, target, db)
|
||||
}
|
||||
|
||||
// Runs migration on a specific database instance.
|
||||
func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) {
|
||||
current, err := EnsureDBVersion(conf, db)
|
||||
func RunMigrations(db *sql.DB, dir string, target int64) (err error) {
|
||||
current, err := EnsureDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrations, err := CollectMigrations(migrationsDir, current, target)
|
||||
migrations, err := CollectMigrations(dir, current, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -80,16 +63,15 @@ func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql
|
|||
direction := current < target
|
||||
ms.Sort(direction)
|
||||
|
||||
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n",
|
||||
conf.Env, current, target)
|
||||
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", current, target)
|
||||
|
||||
for _, m := range ms {
|
||||
|
||||
switch filepath.Ext(m.Source) {
|
||||
case ".go":
|
||||
err = runGoMigration(conf, m.Source, m.Version, direction)
|
||||
// case ".go":
|
||||
// err = runGoMigration(m.Source, m.Version, direction)
|
||||
case ".sql":
|
||||
err = runSQLMigration(conf, db, m.Source, m.Version, direction)
|
||||
err = runSQLMigration(db, m.Source, m.Version, direction)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -192,12 +174,12 @@ func NumericComponent(name string) (int64, error) {
|
|||
|
||||
// retrieve the current version for this DB.
|
||||
// Create and initialize the DB version table if it doesn't exist.
|
||||
func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
|
||||
func EnsureDBVersion(db *sql.DB) (int64, error) {
|
||||
|
||||
rows, err := conf.Driver.Dialect.dbVersionQuery(db)
|
||||
rows, err := dialectByName("postgres").dbVersionQuery(db)
|
||||
if err != nil {
|
||||
if err == ErrTableDoesNotExist {
|
||||
return 0, createVersionTable(conf, db)
|
||||
return 0, createVersionTable(db)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
@ -242,13 +224,13 @@ func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
|
|||
|
||||
// Create the goose_db_version table
|
||||
// and insert the initial 0 value into it
|
||||
func createVersionTable(conf *DBConf, db *sql.DB) error {
|
||||
func createVersionTable(db *sql.DB) error {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d := conf.Driver.Dialect
|
||||
d := dialectByName("postgres")
|
||||
|
||||
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
|
||||
txn.Rollback()
|
||||
|
@ -267,15 +249,8 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
|
|||
|
||||
// wrapper for EnsureDBVersion for callers that don't already have
|
||||
// their own DB instance
|
||||
func GetDBVersion(conf *DBConf) (version int64, err error) {
|
||||
|
||||
db, err := OpenDBFromDBConf(conf)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
version, err = EnsureDBVersion(conf, db)
|
||||
func GetDBVersion(db *sql.DB) (int64, error) {
|
||||
version, err := EnsureDBVersion(db)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
@ -372,10 +347,10 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
|
|||
|
||||
// Update the version table for the given migration,
|
||||
// and finalize the transaction.
|
||||
func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
|
||||
func FinalizeMigration(txn *sql.Tx, direction bool, v int64) error {
|
||||
|
||||
// XXX: drop goose_db_version table on some minimum version number?
|
||||
stmt := conf.Driver.Dialect.insertVersionSql()
|
||||
stmt := dialectByName("postgres").insertVersionSql()
|
||||
if _, err := txn.Exec(stmt, v, direction); err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
|
|
|
@ -135,7 +135,7 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
|
|||
//
|
||||
// All statements following an Up or Down directive are grouped together
|
||||
// until another direction directive is found.
|
||||
func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direction bool) error {
|
||||
func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error {
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
|
@ -160,7 +160,7 @@ func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direc
|
|||
}
|
||||
}
|
||||
|
||||
if err = FinalizeMigration(conf, txn, direction, v); err != nil {
|
||||
if err = FinalizeMigration(txn, direction, v); err != nil {
|
||||
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(scriptFile), err)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func Redo(db *sql.DB, dir string) error {
|
||||
current, err := GetDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
previous, err := GetPreviousDBVersion(dir, current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RunMigrations(db, dir, previous); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RunMigrations(db, dir, current); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Status(db *sql.DB, dir string) error {
|
||||
// collect all migrations
|
||||
min := int64(0)
|
||||
max := int64((1 << 63) - 1)
|
||||
migrations, err := CollectMigrations(dir, min, max)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// must ensure that the version table exists if we're running on a pristine DB
|
||||
if _, err := EnsureDBVersion(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("goose: status")
|
||||
fmt.Println(" Applied At Migration")
|
||||
fmt.Println(" =======================================")
|
||||
for _, m := range migrations {
|
||||
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printMigrationStatus(db *sql.DB, version int64, script string) {
|
||||
var row MigrationRecord
|
||||
q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version)
|
||||
e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
|
||||
|
||||
if e != nil && e != sql.ErrNoRows {
|
||||
log.Fatal(e)
|
||||
}
|
||||
|
||||
var appliedAt string
|
||||
|
||||
if row.IsApplied {
|
||||
appliedAt = row.TStamp.Format(time.ANSIC)
|
||||
} else {
|
||||
appliedAt = "Pending"
|
||||
}
|
||||
|
||||
fmt.Printf(" %-24s -- %v\n", appliedAt, script)
|
||||
}
|
Loading…
Reference in New Issue