dialect: use insertVersionSql() to ensure txn.Exec() args are handled properly in each dialect

pull/2/head
Liam Staskawicz 2013-07-08 23:37:30 -07:00
parent 2847f34016
commit bab8917da5
4 changed files with 33 additions and 25 deletions

View File

@ -41,7 +41,7 @@ func (pg *PostgresDialect) createVersionTableSql() string {
}
func (pg *PostgresDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);"
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
}
func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
@ -74,7 +74,7 @@ func (m *MySqlDialect) createVersionTableSql() string {
}
func (m *MySqlDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);"
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
}
func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {

View File

@ -78,7 +78,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) {
case ".go":
e = runGoMigration(conf, m.Source, m.Version, mm.Direction)
case ".sql":
e = runSQLMigration(db, m.Source, m.Version, mm.Direction)
e = runSQLMigration(conf, db, m.Source, m.Version, mm.Direction)
}
if e != nil {
@ -249,11 +249,17 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
}
d := conf.Driver.Dialect
for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} {
if _, err := txn.Exec(str); err != nil {
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback()
return err
}
version := 0
applied := true
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
txn.Rollback()
return err
}
return txn.Commit()

View File

@ -15,6 +15,7 @@ type TemplateData struct {
Driver DBDriver
Direction bool
Func string
InsertStmt string
}
//
@ -43,6 +44,7 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
Driver: conf.Driver,
Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version),
InsertStmt: conf.Driver.Dialect.insertVersionSql(),
}
main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationTmpl, td)
if e != nil {
@ -93,7 +95,7 @@ func main() {
{{ .Func }}(txn)
// XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
stmt := "{{ .InsertStmt }}"
if _, err = txn.Exec(stmt, {{ .Version }}, {{ .Direction }}); err != nil {
txn.Rollback()
log.Fatal("failed to write version: ", err)

View File

@ -16,7 +16,7 @@ import (
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction bool) error {
txn, err := db.Begin()
if err != nil {
@ -71,7 +71,7 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
filepath.Base(script))
}
if err = finalizeMigration(txn, direction, v); err != nil {
if err = finalizeMigration(conf, txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err)
}
@ -80,11 +80,11 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
// Update the version table for the given migration,
// and finalize the transaction.
func finalizeMigration(txn *sql.Tx, direction bool, v int64) error {
func finalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)"
if _, err := txn.Exec(stmt, v, direction); err != nil {
d := conf.Driver.Dialect
if _, err := txn.Exec(d.insertVersionSql(), v, direction); err != nil {
txn.Rollback()
return err
}