diff --git a/README.md b/README.md index 48f3996..8583a9c 100644 --- a/README.md +++ b/README.md @@ -128,14 +128,17 @@ You may include as many environments as you like, and you can use the `-env` com goose will expand environment variables in the `open` element. For an example, see the Heroku section below. ## Other Drivers -goose knows about some common SQL drivers, but it can still be used to run Go-based migrations with any driver supported by database/sql. +goose knows about some common SQL drivers, but it can still be used to run Go-based migrations with any driver supported by database/sql. An import path and known dialect are required. -To run Go-based migrations with another driver, specify its import path, as shown below. +Currently, available dialects are: "postgres" or "mysql" + +To run Go-based migrations with another driver, specify its import path and dialect, as shown below. customdriver: driver: custom open: custom open string import: github.com/custom/driver + dialect: mysql NOTE: Because migrations written in SQL are executed directly by the goose binary, only drivers compiled into goose may be used for these migrations. diff --git a/db-sample/dbconf.yml b/db-sample/dbconf.yml index 52c355c..65672fc 100644 --- a/db-sample/dbconf.yml +++ b/db-sample/dbconf.yml @@ -15,3 +15,4 @@ customimport: driver: customdriver open: customdriver open import: github.com/custom/driver + dialect: mysql diff --git a/dbconf.go b/dbconf.go index ce929b2..4cf17fd 100644 --- a/dbconf.go +++ b/dbconf.go @@ -20,6 +20,7 @@ type DBDriver struct { Name string OpenStr string Import string + Dialect SqlDialect } type DBConf struct { @@ -70,6 +71,11 @@ func newDBConfDetails(p, env string) (*DBConf, error) { d.Import = imprt } + // allow the configuration to override the Dialect for this driver + if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil { + d.Dialect = DialectByName(dialect) + } + if !d.IsValid() { return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d)) } @@ -94,9 +100,11 @@ func NewDBDriver(name, open string) DBDriver { switch name { case "postgres": d.Import = "github.com/lib/pq" + d.Dialect = &PostgresDialect{} case "mymysql": d.Import = "github.com/ziutek/mymysql/godrv" + d.Dialect = &MySqlDialect{} } return d @@ -108,5 +116,9 @@ func (drv *DBDriver) IsValid() bool { return false } + if drv.Dialect == nil { + return false + } + return true } diff --git a/dbconf_test.go b/dbconf_test.go index 457fad9..dcf6a2c 100644 --- a/dbconf_test.go +++ b/dbconf_test.go @@ -9,7 +9,7 @@ func TestBasics(t *testing.T) { dbconf, err := newDBConfDetails("db-sample", "test") if err != nil { - t.Error("couldn't create DBConf") + t.Fatal(err) } got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr} @@ -26,7 +26,7 @@ func TestImportOverride(t *testing.T) { dbconf, err := newDBConfDetails("db-sample", "customimport") if err != nil { - t.Error("couldn't create DBConf") + t.Fatal(err) } got := dbconf.Driver.Import diff --git a/dialect.go b/dialect.go new file mode 100644 index 0000000..c3a14f3 --- /dev/null +++ b/dialect.go @@ -0,0 +1,91 @@ +package main + +import ( + "database/sql" +) + +// SqlDialect abstracts the details of specific SQL dialects +// for goose's few SQL specific statements +type SqlDialect interface { + createVersionTableSql() string // sql string to create the goose_db_version table + insertVersionSql() string // sql string to insert the initial version table row + dbVersionQuery(db *sql.DB) (*sql.Rows, error) +} + +// drivers that we don't know about can ask for a dialect by name +func DialectByName(d string) SqlDialect { + switch d { + case "postgres": + return &PostgresDialect{} + case "mysql": + return &MySqlDialect{} + } + + return nil +} + +//////////////////////////// +// Postgres +//////////////////////////// + +type PostgresDialect struct{} + +func (pg *PostgresDialect) createVersionTableSql() string { + return `CREATE TABLE goose_db_version ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + );` +} + +func (pg *PostgresDialect) insertVersionSql() string { + return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" +} + +func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC") + + // XXX: check for postgres specific error indicating the table doesn't exist. + // for now, assume any error is because the table doesn't exist, + // in which case we'll try to create it. + if err != nil { + return nil, ErrTableDoesNotExist + } + + return rows, err +} + +//////////////////////////// +// MySQL +//////////////////////////// + +type MySqlDialect struct{} + +func (m *MySqlDialect) createVersionTableSql() string { + return `CREATE TABLE goose_db_version ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + );` +} + +func (m *MySqlDialect) insertVersionSql() string { + return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" +} + +func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC") + + // XXX: check for mysql specific error indicating the table doesn't exist. + // for now, assume any error is because the table doesn't exist, + // in which case we'll try to create it. + if err != nil { + return nil, ErrTableDoesNotExist + } + + return rows, err +} diff --git a/migrate.go b/migrate.go index 16ce0fd..842da76 100644 --- a/migrate.go +++ b/migrate.go @@ -15,6 +15,8 @@ import ( "time" ) +var ErrTableDoesNotExist = errors.New("table does not exist") + type MigrationRecord struct { VersionId int64 TStamp time.Time @@ -192,11 +194,14 @@ func numericComponent(name string) (int64, error) { // Create and initialize the DB version table if it doesn't exist. func ensureDBVersion(conf *DBConf, db *sql.DB) (int64, error) { - rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC;") + rows, err := conf.Driver.Dialect.dbVersionQuery(db) if err != nil { - // XXX: cross platform method to detect failure reason - // for now, assume it was because the table didn't exist, and try to create it - return 0, createVersionTable(conf, db) + + if err == ErrTableDoesNotExist { + return 0, createVersionTable(conf, db) + } + + return 0, err } // The most recent record for each migration specifies @@ -243,17 +248,8 @@ func createVersionTable(conf *DBConf, db *sql.DB) error { return err } - // create the table and insert an initial value of 0 - create := `CREATE TABLE goose_db_version ( - id serial NOT NULL, - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default now(), - PRIMARY KEY(id) - );` - insert := "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" - - for _, str := range []string{create, insert} { + d := conf.Driver.Dialect + for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} { if _, err := txn.Exec(str); err != nil { txn.Rollback() return err