diff --git a/dialect.go b/dialect.go index 1f21aa6..e79ad81 100644 --- a/dialect.go +++ b/dialect.go @@ -11,6 +11,7 @@ type SQLDialect interface { createVersionTableSQL() string // sql string to create the db version table insertVersionSQL() string // sql string to insert the initial version table row updateVersionSQL() string // sql string to update version + deleteVersionSQL() string // sql string to delete version dbVersionQuery(db *sql.DB) (*sql.Rows, error) } @@ -75,6 +76,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (pg PostgresDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +} + //////////////////////////// // MySQL //////////////////////////// @@ -109,6 +114,10 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m MySQLDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +} + //////////////////////////// // sqlite3 //////////////////////////// @@ -142,6 +151,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m Sqlite3Dialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +} + //////////////////////////// // Redshift //////////////////////////// @@ -176,6 +189,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (rs RedshiftDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +} + //////////////////////////// // TiDB //////////////////////////// @@ -209,3 +226,7 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } + +func (m TiDBDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) +} diff --git a/migration.go b/migration.go index 595b54a..853b72c 100644 --- a/migration.go +++ b/migration.go @@ -77,9 +77,17 @@ func (m *Migration) run(db *sql.DB, direction bool) error { return err } } - if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { - tx.Rollback() - return err + + if direction { + if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { + tx.Rollback() + return err + } + } else { + if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil { + tx.Rollback() + return err + } } return tx.Commit() diff --git a/migration_sql.go b/migration_sql.go index 6bea400..0f4f48e 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -164,9 +164,17 @@ func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) err return err } } - if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil { - tx.Rollback() - return err + + if direction { + if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil { + tx.Rollback() + return err + } + } else { + if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil { + tx.Rollback() + return err + } } return tx.Commit()