From a26643fb2bfd925ae85fb871c3475c9e44b999e2 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 20 Jun 2017 16:30:29 -0400 Subject: [PATCH] Fix migrations w/o TX, refactor --- migrate.go | 7 +++- migration.go | 46 ++++++++------------- migration_sql.go | 101 +++++++++++++---------------------------------- 3 files changed, 48 insertions(+), 106 deletions(-) diff --git a/migrate.go b/migrate.go index 197f7f2..edbced5 100644 --- a/migrate.go +++ b/migrate.go @@ -98,6 +98,7 @@ func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.T if existing, ok := registeredGoMigrations[v]; ok { panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) } + registeredGoMigrations[v] = migration } @@ -141,12 +142,14 @@ func CollectMigrations(dirpath string, current, target int64) (Migrations, error for _, file := range goMigrationFiles { v, err := NumericComponent(file) if err != nil { - continue // Skip any files that don't have start with version. + continue // Skip any files that don't have version prefix. } - // Skip migrations already registered via goose.AddMigration(). + + // Skip migrations already existing migrations registered via goose.AddMigration(). if _, ok := registeredGoMigrations[v]; ok { continue } + if versionFilter(v, current, target) { migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false} migrations = append(migrations, migration) diff --git a/migration.go b/migration.go index fafd2e3..61774ec 100644 --- a/migration.go +++ b/migration.go @@ -35,12 +35,20 @@ func (m *Migration) String() string { // Up runs an up migration. func (m *Migration) Up(db *sql.DB) error { - return m.run(db, true) + if err := m.run(db, true); err != nil { + return err + } + fmt.Println("OK ", filepath.Base(m.Source)) + return nil } // Down runs a down migration. func (m *Migration) Down(db *sql.DB) error { - return m.run(db, false) + if err := m.run(db, false); err != nil { + return err + } + fmt.Println("OK ", filepath.Base(m.Source)) + return nil } func (m *Migration) run(db *sql.DB, direction bool) error { @@ -70,9 +78,13 @@ func (m *Migration) run(db *sql.DB, direction bool) error { return err } } - } + if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { + tx.Rollback() + return err + } - fmt.Println("OK ", filepath.Base(m.Source)) + return tx.Commit() + } return nil } @@ -100,29 +112,3 @@ func NumericComponent(name string) (int64, error) { return n, e } - -// FinalizeMigration updates the version table for the given migration, -// and finalize the transaction. -func FinalizeMigrationTx(tx *sql.Tx, direction bool, v int64) error { - - // XXX: drop goose_db_version table on some minimum version number? - stmt := GetDialect().insertVersionSQL() - if _, err := tx.Exec(stmt, v, direction); err != nil { - tx.Rollback() - return err - } - - return tx.Commit() -} - -// Update the version table for the given migration without a transaction. -func FinalizeMigration(db *sql.DB, direction bool, v int64) error { - - // XXX: drop goose_db_version table on some minimum version number? - stmt := GetDialect().insertVersionSQL() - if _, err := db.Exec(stmt, v, direction); err != nil { - return err - } - - return nil -} diff --git a/migration_sql.go b/migration_sql.go index 3b56867..4f1d31d 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -7,8 +7,6 @@ import ( "io" "log" "os" - "path/filepath" - "regexp" "strings" ) @@ -42,8 +40,7 @@ func endsWithSemicolon(line string) bool { // within a statement. For these cases, we provide the explicit annotations // 'StatementBegin' and 'StatementEnd' to allow the script to // tell us to ignore semicolons. -func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { - +func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) { var buf bytes.Buffer scanner := bufio.NewScanner(r) @@ -55,6 +52,7 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { statementEnded := false ignoreSemicolons := false directionIsActive := false + tx = true for scanner.Scan() { @@ -86,6 +84,10 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { ignoreSemicolons = false } break + + case "NO TRANSACTION": + tx = false + break } } @@ -128,29 +130,6 @@ func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { return } -func useTransactions(scriptFile string) bool { - f, err := os.Open(scriptFile) - if err != nil { - log.Fatal(err) - } - - noTransactionsRegex, _ := regexp.Compile("--\\s\\+goose\\sNO\\sTRANSACTION") - - scanner := bufio.NewScanner(f) - - for scanner.Scan() { - line := scanner.Text() - - if noTransactionsRegex.MatchString(line) { - f.Close() - return false - } - } - - f.Close() - return true -} - // Run a migration specified in raw SQL. // // Sections of the script can be annotated with a special comment, @@ -160,71 +139,45 @@ func useTransactions(scriptFile string) bool { // All statements following an Up or Down directive are grouped together // until another direction directive is found. func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error { - filePath := filepath.Base(scriptFile) - useTx := useTransactions(scriptFile) - f, err := os.Open(scriptFile) if err != nil { log.Fatal(err) } + defer f.Close() - if err != nil { - log.Fatal(err) - } + statements, useTx := getSQLStatements(f, direction) + log.Printf("gonna use TX: %v\n", useTx) if useTx { - err := runMigrationInTransaction(db, f, v, direction, filePath) + // TRANSACTION. + + tx, err := db.Begin() if err != nil { - log.Fatalf("FAIL (tx) %s (%v), quitting migration.", filePath, err) + log.Fatal(err) } - } else { - err = runMigrationWithoutTransaction(db, f, v, direction, filePath) - if err != nil { - log.Fatalf("FAIL (no tx) %s (%v), quitting migration.", filePath, err) + + for _, query := range statements { + if _, err = tx.Exec(query); err != nil { + tx.Rollback() + return err + } } - } - - f.Close() - - return nil -} - -// Run the migration within a transaction (recommended) -func runMigrationInTransaction(db *sql.DB, r io.Reader, v int64, direction bool, filePath string) error { - txn, err := db.Begin() - if err != nil { - log.Fatal(err) - } - - // find each statement, checking annotations for up/down direction - // Commits the transaction if successfully applied each statement and - // records the version into the version table or returns an error and - // rolls back the transaction. - for _, query := range splitSQLStatements(r, direction) { - if _, err = txn.Exec(query); err != nil { - txn.Rollback() + if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil { + tx.Rollback() return err } + + return tx.Commit() } - if err = FinalizeMigrationTx(txn, direction, v); err != nil { - log.Fatalf("error finalizing migration %s, quitting. (%v)", filePath, err) - } - - return nil -} - -func runMigrationWithoutTransaction(db *sql.DB, r io.Reader, v int64, direction bool, filePath string) error { - // find each statement, checking annotations for up/down direction - // Tecords the version into the version table or returns an error - for _, query := range splitSQLStatements(r, direction) { + // NO TRANSACTION. + for _, query := range statements { if _, err := db.Exec(query); err != nil { return err } } - - if err := FinalizeMigration(db, direction, v); err != nil { - log.Fatalf("error finalizing migration %s, quitting. (%v)", filePath, err) + if _, err := db.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil { + return err } return nil