Fix migrations w/o TX, refactor

pull/56/head
Vojtech Vitek 2017-06-20 16:30:29 -04:00
parent 09b1a1b116
commit a26643fb2b
3 changed files with 48 additions and 106 deletions

View File

@ -98,6 +98,7 @@ func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.T
if existing, ok := registeredGoMigrations[v]; ok { if existing, ok := registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
} }
registeredGoMigrations[v] = migration registeredGoMigrations[v] = migration
} }
@ -141,12 +142,14 @@ func CollectMigrations(dirpath string, current, target int64) (Migrations, error
for _, file := range goMigrationFiles { for _, file := range goMigrationFiles {
v, err := NumericComponent(file) v, err := NumericComponent(file)
if err != nil { if err != nil {
continue // Skip any files that don't have start with version. continue // Skip any files that don't have version prefix.
} }
// Skip migrations already registered via goose.AddMigration().
// Skip migrations already existing migrations registered via goose.AddMigration().
if _, ok := registeredGoMigrations[v]; ok { if _, ok := registeredGoMigrations[v]; ok {
continue continue
} }
if versionFilter(v, current, target) { if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false} migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false}
migrations = append(migrations, migration) migrations = append(migrations, migration)

View File

@ -35,12 +35,20 @@ func (m *Migration) String() string {
// Up runs an up migration. // Up runs an up migration.
func (m *Migration) Up(db *sql.DB) error { func (m *Migration) Up(db *sql.DB) error {
return m.run(db, true) if err := m.run(db, true); err != nil {
return err
}
fmt.Println("OK ", filepath.Base(m.Source))
return nil
} }
// Down runs a down migration. // Down runs a down migration.
func (m *Migration) Down(db *sql.DB) error { func (m *Migration) Down(db *sql.DB) error {
return m.run(db, false) if err := m.run(db, false); err != nil {
return err
}
fmt.Println("OK ", filepath.Base(m.Source))
return nil
} }
func (m *Migration) run(db *sql.DB, direction bool) error { func (m *Migration) run(db *sql.DB, direction bool) error {
@ -70,9 +78,13 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
return err return err
} }
} }
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return err
} }
fmt.Println("OK ", filepath.Base(m.Source)) return tx.Commit()
}
return nil return nil
} }
@ -100,29 +112,3 @@ func NumericComponent(name string) (int64, error) {
return n, e return n, e
} }
// FinalizeMigration updates the version table for the given migration,
// and finalize the transaction.
func FinalizeMigrationTx(tx *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := GetDialect().insertVersionSQL()
if _, err := tx.Exec(stmt, v, direction); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
// Update the version table for the given migration without a transaction.
func FinalizeMigration(db *sql.DB, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := GetDialect().insertVersionSQL()
if _, err := db.Exec(stmt, v, direction); err != nil {
return err
}
return nil
}

View File

@ -7,8 +7,6 @@ import (
"io" "io"
"log" "log"
"os" "os"
"path/filepath"
"regexp"
"strings" "strings"
) )
@ -42,8 +40,7 @@ func endsWithSemicolon(line string) bool {
// within a statement. For these cases, we provide the explicit annotations // within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to // 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons. // tell us to ignore semicolons.
func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) {
var buf bytes.Buffer var buf bytes.Buffer
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
@ -55,6 +52,7 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
statementEnded := false statementEnded := false
ignoreSemicolons := false ignoreSemicolons := false
directionIsActive := false directionIsActive := false
tx = true
for scanner.Scan() { for scanner.Scan() {
@ -86,6 +84,10 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
ignoreSemicolons = false ignoreSemicolons = false
} }
break break
case "NO TRANSACTION":
tx = false
break
} }
} }
@ -128,29 +130,6 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
return return
} }
func useTransactions(scriptFile string) bool {
f, err := os.Open(scriptFile)
if err != nil {
log.Fatal(err)
}
noTransactionsRegex, _ := regexp.Compile("--\\s\\+goose\\sNO\\sTRANSACTION")
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if noTransactionsRegex.MatchString(line) {
f.Close()
return false
}
}
f.Close()
return true
}
// Run a migration specified in raw SQL. // Run a migration specified in raw SQL.
// //
// Sections of the script can be annotated with a special comment, // Sections of the script can be annotated with a special comment,
@ -160,71 +139,45 @@ func useTransactions(scriptFile string) bool {
// All statements following an Up or Down directive are grouped together // All statements following an Up or Down directive are grouped together
// until another direction directive is found. // until another direction directive is found.
func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error { func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error {
filePath := filepath.Base(scriptFile)
useTx := useTransactions(scriptFile)
f, err := os.Open(scriptFile) f, err := os.Open(scriptFile)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer f.Close()
if err != nil { statements, useTx := getSQLStatements(f, direction)
log.Fatal(err) log.Printf("gonna use TX: %v\n", useTx)
}
if useTx { if useTx {
err := runMigrationInTransaction(db, f, v, direction, filePath) // TRANSACTION.
if err != nil {
log.Fatalf("FAIL (tx) %s (%v), quitting migration.", filePath, err)
}
} else {
err = runMigrationWithoutTransaction(db, f, v, direction, filePath)
if err != nil {
log.Fatalf("FAIL (no tx) %s (%v), quitting migration.", filePath, err)
}
}
f.Close() tx, err := db.Begin()
return nil
}
// Run the migration within a transaction (recommended)
func runMigrationInTransaction(db *sql.DB, r io.Reader, v int64, direction bool, filePath string) error {
txn, err := db.Begin()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// find each statement, checking annotations for up/down direction for _, query := range statements {
// Commits the transaction if successfully applied each statement and if _, err = tx.Exec(query); err != nil {
// records the version into the version table or returns an error and tx.Rollback()
// rolls back the transaction.
for _, query := range splitSQLStatements(r, direction) {
if _, err = txn.Exec(query); err != nil {
txn.Rollback()
return err return err
} }
} }
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
if err = FinalizeMigrationTx(txn, direction, v); err != nil { tx.Rollback()
log.Fatalf("error finalizing migration %s, quitting. (%v)", filePath, err) return err
} }
return nil return tx.Commit()
} }
func runMigrationWithoutTransaction(db *sql.DB, r io.Reader, v int64, direction bool, filePath string) error { // NO TRANSACTION.
// find each statement, checking annotations for up/down direction for _, query := range statements {
// Tecords the version into the version table or returns an error
for _, query := range splitSQLStatements(r, direction) {
if _, err := db.Exec(query); err != nil { if _, err := db.Exec(query); err != nil {
return err return err
} }
} }
if _, err := db.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
if err := FinalizeMigration(db, direction, v); err != nil { return err
log.Fatalf("error finalizing migration %s, quitting. (%v)", filePath, err)
} }
return nil return nil