dialect: use insertVersionSql() to ensure txn.Exec() args are handled properly in each dialect

pull/2/head
Liam Staskawicz 2013-07-08 23:37:30 -07:00
parent 2847f34016
commit bab8917da5
4 changed files with 33 additions and 25 deletions

View File

@ -37,11 +37,11 @@ func (pg *PostgresDialect) createVersionTableSql() string {
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`
);`
}
func (pg *PostgresDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);"
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
}
func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
@ -70,11 +70,11 @@ func (m *MySqlDialect) createVersionTableSql() string {
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`
);`
}
func (m *MySqlDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);"
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
}
func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {

View File

@ -78,7 +78,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) {
case ".go":
e = runGoMigration(conf, m.Source, m.Version, mm.Direction)
case ".sql":
e = runSQLMigration(db, m.Source, m.Version, mm.Direction)
e = runSQLMigration(conf, db, m.Source, m.Version, mm.Direction)
}
if e != nil {
@ -89,7 +89,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) {
}
}
// collect all the valid looking migration scripts in the
// collect all the valid looking migration scripts in the
// migrations folder, and key them by version
func collectMigrations(dirpath string, current, target int64) (mm *MigrationMap, err error) {
@ -249,11 +249,17 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
}
d := conf.Driver.Dialect
for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} {
if _, err := txn.Exec(str); err != nil {
txn.Rollback()
return err
}
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback()
return err
}
version := 0
applied := true
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
txn.Rollback()
return err
}
return txn.Commit()

View File

@ -11,10 +11,11 @@ import (
)
type TemplateData struct {
Version int64
Driver DBDriver
Direction bool
Func string
Version int64
Driver DBDriver
Direction bool
Func string
InsertStmt string
}
//
@ -39,10 +40,11 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
}
td := &TemplateData{
Version: version,
Driver: conf.Driver,
Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version),
Version: version,
Driver: conf.Driver,
Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version),
InsertStmt: conf.Driver.Dialect.insertVersionSql(),
}
main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationTmpl, td)
if e != nil {
@ -93,7 +95,7 @@ func main() {
{{ .Func }}(txn)
// XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
stmt := "{{ .InsertStmt }}"
if _, err = txn.Exec(stmt, {{ .Version }}, {{ .Direction }}); err != nil {
txn.Rollback()
log.Fatal("failed to write version: ", err)

View File

@ -16,7 +16,7 @@ import (
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction bool) error {
txn, err := db.Begin()
if err != nil {
@ -71,7 +71,7 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
filepath.Base(script))
}
if err = finalizeMigration(txn, direction, v); err != nil {
if err = finalizeMigration(conf, txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err)
}
@ -80,11 +80,11 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
// Update the version table for the given migration,
// and finalize the transaction.
func finalizeMigration(txn *sql.Tx, direction bool, v int64) error {
func finalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)"
if _, err := txn.Exec(stmt, v, direction); err != nil {
d := conf.Driver.Dialect
if _, err := txn.Exec(d.insertVersionSql(), v, direction); err != nil {
txn.Rollback()
return err
}