diff --git a/dialect.go b/dialect.go index c3a14f3..c5a0d27 100644 --- a/dialect.go +++ b/dialect.go @@ -37,11 +37,11 @@ func (pg *PostgresDialect) createVersionTableSql() string { is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) - );` + );` } func (pg *PostgresDialect) insertVersionSql() string { - return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" + return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" } func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { @@ -70,11 +70,11 @@ func (m *MySqlDialect) createVersionTableSql() string { is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) - );` + );` } func (m *MySqlDialect) insertVersionSql() string { - return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" + return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);" } func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { diff --git a/migrate.go b/migrate.go index 842da76..17868bc 100644 --- a/migrate.go +++ b/migrate.go @@ -78,7 +78,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) { case ".go": e = runGoMigration(conf, m.Source, m.Version, mm.Direction) case ".sql": - e = runSQLMigration(db, m.Source, m.Version, mm.Direction) + e = runSQLMigration(conf, db, m.Source, m.Version, mm.Direction) } if e != nil { @@ -89,7 +89,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) { } } -// collect all the valid looking migration scripts in the +// collect all the valid looking migration scripts in the // migrations folder, and key them by version func collectMigrations(dirpath string, current, target int64) (mm *MigrationMap, err error) { @@ -249,11 +249,17 @@ func createVersionTable(conf *DBConf, db *sql.DB) error { } d := conf.Driver.Dialect - for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} { - if _, err := txn.Exec(str); err != nil { - txn.Rollback() - return err - } + + if _, err := txn.Exec(d.createVersionTableSql()); err != nil { + txn.Rollback() + return err + } + + version := 0 + applied := true + if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil { + txn.Rollback() + return err } return txn.Commit() diff --git a/migration_go.go b/migration_go.go index 4f0d525..3790934 100644 --- a/migration_go.go +++ b/migration_go.go @@ -11,10 +11,11 @@ import ( ) type TemplateData struct { - Version int64 - Driver DBDriver - Direction bool - Func string + Version int64 + Driver DBDriver + Direction bool + Func string + InsertStmt string } // @@ -39,10 +40,11 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er } td := &TemplateData{ - Version: version, - Driver: conf.Driver, - Direction: direction, - Func: fmt.Sprintf("%v_%v", directionStr, version), + Version: version, + Driver: conf.Driver, + Direction: direction, + Func: fmt.Sprintf("%v_%v", directionStr, version), + InsertStmt: conf.Driver.Dialect.insertVersionSql(), } main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationTmpl, td) if e != nil { @@ -93,7 +95,7 @@ func main() { {{ .Func }}(txn) // XXX: drop goose_db_version table on some minimum version number? - stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" + stmt := "{{ .InsertStmt }}" if _, err = txn.Exec(stmt, {{ .Version }}, {{ .Direction }}); err != nil { txn.Rollback() log.Fatal("failed to write version: ", err) diff --git a/migration_sql.go b/migration_sql.go index fbaf3a7..d2a2f2e 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -16,7 +16,7 @@ import ( // // All statements following an Up or Down directive are grouped together // until another direction directive is found. -func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error { +func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction bool) error { txn, err := db.Begin() if err != nil { @@ -71,7 +71,7 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error { filepath.Base(script)) } - if err = finalizeMigration(txn, direction, v); err != nil { + if err = finalizeMigration(conf, txn, direction, v); err != nil { log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err) } @@ -80,11 +80,11 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error { // Update the version table for the given migration, // and finalize the transaction. -func finalizeMigration(txn *sql.Tx, direction bool, v int64) error { +func finalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error { // XXX: drop goose_db_version table on some minimum version number? - stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)" - if _, err := txn.Exec(stmt, v, direction); err != nil { + d := conf.Driver.Dialect + if _, err := txn.Exec(d.insertVersionSql(), v, direction); err != nil { txn.Rollback() return err }