dialect: use insertVersionSql() to ensure txn.Exec() args are handled properly in each dialect

pull/2/head
Liam Staskawicz 2013-07-08 23:37:30 -07:00
parent 2847f34016
commit bab8917da5
4 changed files with 33 additions and 25 deletions

View File

@ -37,11 +37,11 @@ func (pg *PostgresDialect) createVersionTableSql() string {
is_applied boolean NOT NULL, is_applied boolean NOT NULL,
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
);` );`
} }
func (pg *PostgresDialect) insertVersionSql() string { func (pg *PostgresDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
} }
func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (pg *PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
@ -70,11 +70,11 @@ func (m *MySqlDialect) createVersionTableSql() string {
is_applied boolean NOT NULL, is_applied boolean NOT NULL,
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
);` );`
} }
func (m *MySqlDialect) insertVersionSql() string { func (m *MySqlDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (0, true);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
} }
func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (m *MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {

View File

@ -78,7 +78,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) {
case ".go": case ".go":
e = runGoMigration(conf, m.Source, m.Version, mm.Direction) e = runGoMigration(conf, m.Source, m.Version, mm.Direction)
case ".sql": case ".sql":
e = runSQLMigration(db, m.Source, m.Version, mm.Direction) e = runSQLMigration(conf, db, m.Source, m.Version, mm.Direction)
} }
if e != nil { if e != nil {
@ -89,7 +89,7 @@ func runMigrations(conf *DBConf, migrationsDir string, target int64) {
} }
} }
// collect all the valid looking migration scripts in the // collect all the valid looking migration scripts in the
// migrations folder, and key them by version // migrations folder, and key them by version
func collectMigrations(dirpath string, current, target int64) (mm *MigrationMap, err error) { func collectMigrations(dirpath string, current, target int64) (mm *MigrationMap, err error) {
@ -249,11 +249,17 @@ func createVersionTable(conf *DBConf, db *sql.DB) error {
} }
d := conf.Driver.Dialect d := conf.Driver.Dialect
for _, str := range []string{d.createVersionTableSql(), d.insertVersionSql()} {
if _, err := txn.Exec(str); err != nil { if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback() txn.Rollback()
return err return err
} }
version := 0
applied := true
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
txn.Rollback()
return err
} }
return txn.Commit() return txn.Commit()

View File

@ -11,10 +11,11 @@ import (
) )
type TemplateData struct { type TemplateData struct {
Version int64 Version int64
Driver DBDriver Driver DBDriver
Direction bool Direction bool
Func string Func string
InsertStmt string
} }
// //
@ -39,10 +40,11 @@ func runGoMigration(conf *DBConf, path string, version int64, direction bool) er
} }
td := &TemplateData{ td := &TemplateData{
Version: version, Version: version,
Driver: conf.Driver, Driver: conf.Driver,
Direction: direction, Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version), Func: fmt.Sprintf("%v_%v", directionStr, version),
InsertStmt: conf.Driver.Dialect.insertVersionSql(),
} }
main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationTmpl, td) main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationTmpl, td)
if e != nil { if e != nil {
@ -93,7 +95,7 @@ func main() {
{{ .Func }}(txn) {{ .Func }}(txn)
// XXX: drop goose_db_version table on some minimum version number? // XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" stmt := "{{ .InsertStmt }}"
if _, err = txn.Exec(stmt, {{ .Version }}, {{ .Direction }}); err != nil { if _, err = txn.Exec(stmt, {{ .Version }}, {{ .Direction }}); err != nil {
txn.Rollback() txn.Rollback()
log.Fatal("failed to write version: ", err) log.Fatal("failed to write version: ", err)

View File

@ -16,7 +16,7 @@ import (
// //
// All statements following an Up or Down directive are grouped together // All statements following an Up or Down directive are grouped together
// until another direction directive is found. // until another direction directive is found.
func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error { func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction bool) error {
txn, err := db.Begin() txn, err := db.Begin()
if err != nil { if err != nil {
@ -71,7 +71,7 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
filepath.Base(script)) filepath.Base(script))
} }
if err = finalizeMigration(txn, direction, v); err != nil { if err = finalizeMigration(conf, txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err) log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err)
} }
@ -80,11 +80,11 @@ func runSQLMigration(db *sql.DB, script string, v int64, direction bool) error {
// Update the version table for the given migration, // Update the version table for the given migration,
// and finalize the transaction. // and finalize the transaction.
func finalizeMigration(txn *sql.Tx, direction bool, v int64) error { func finalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number? // XXX: drop goose_db_version table on some minimum version number?
stmt := "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)" d := conf.Driver.Dialect
if _, err := txn.Exec(stmt, v, direction); err != nil { if _, err := txn.Exec(d.insertVersionSql(), v, direction); err != nil {
txn.Rollback() txn.Rollback()
return err return err
} }