Refactor; make the new SQL parser build

pull/151/head
Vojtech Vitek 2019-03-05 01:12:32 -05:00
parent 456f34d42d
commit 94c2f51496
4 changed files with 45 additions and 33 deletions

View File

@ -3,6 +3,7 @@ package goose
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
@ -38,7 +39,6 @@ func (m *Migration) Up(db *sql.DB) error {
if err := m.run(db, true); err != nil {
return err
}
log.Println("OK ", filepath.Base(m.Source))
return nil
}
@ -47,51 +47,75 @@ func (m *Migration) Down(db *sql.DB) error {
if err := m.run(db, false); err != nil {
return err
}
log.Println("OK ", filepath.Base(m.Source))
return nil
}
func (m *Migration) run(db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return errors.Wrapf(err, "failed to run SQL migration %q", filepath.Base(m.Source))
f, err := os.Open(m.Source)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to open SQL migration file", filepath.Base(m.Source))
}
defer f.Close()
statements, useTx, err := parseSQLMigrationFile(f, direction)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to parse SQL migration file", filepath.Base(m.Source))
}
if err := runSQLMigration(db, statements, useTx, m.Version, direction); err != nil {
return errors.Wrapf(err, "ERROR %v: failed to run SQL migration", filepath.Base(m.Source))
}
if len(statements) > 0 {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}
case ".go":
if !m.Registered {
return errors.Errorf("failed to run Go migration %q: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
return errors.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.Begin()
if err != nil {
return errors.Wrap(err, "failed to begin transaction")
return errors.Wrap(err, "ERROR failed to begin transaction")
}
fn := m.UpFn
if !direction {
fn = m.DownFn
}
if fn != nil {
// Run Go migration function.
if err := fn(tx); err != nil {
tx.Rollback()
return errors.Wrapf(err, "failed to run Go migration %q", filepath.Base(m.Source))
return errors.Wrapf(err, "ERROR %v: failed to run Go migration function %T", filepath.Base(m.Source), fn)
}
}
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return errors.Wrap(err, "failed to execute transaction")
return errors.Wrap(err, "ERROR failed to execute transaction")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return errors.Wrap(err, "failed to execute transaction")
return errors.Wrap(err, "ERROR failed to execute transaction")
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
return errors.Wrap(err, "ERROR failed to commit transaction")
}
if fn != nil {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}
return nil

View File

@ -2,7 +2,6 @@ package goose
import (
"database/sql"
"os"
"regexp"
"github.com/pkg/errors"
@ -16,22 +15,11 @@ import (
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error {
f, err := os.Open(sqlFile)
if err != nil {
return errors.Wrap(err, "failed to open SQL migration file")
}
defer f.Close()
statements, useTx, err := parseSQLMigrationFile(f, direction)
if err != nil {
return errors.Wrap(err, "failed to parse SQL migration file")
}
func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool) error {
if useTx {
// TRANSACTION.
printInfo("Begin transaction\n")
verboseInfo("Begin transaction\n")
tx, err := db.Begin()
if err != nil {
@ -39,9 +27,9 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
}
for _, query := range statements {
printInfo("Executing statement: %s\n", clearStatement(query))
verboseInfo("Executing statement: %s\n", clearStatement(query))
if _, err = tx.Exec(query); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction\n")
tx.Rollback()
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
@ -49,19 +37,19 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction\n")
tx.Rollback()
return errors.Wrap(err, "failed to insert new goose version")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction\n")
tx.Rollback()
return errors.Wrap(err, "failed to delete goose version")
}
}
printInfo("Commit transaction\n")
verboseInfo("Commit transaction\n")
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
}
@ -71,7 +59,7 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
// NO TRANSACTION.
for _, query := range statements {
printInfo("Executing statement: %s\n", clearStatement(query))
verboseInfo("Executing statement: %s\n", clearStatement(query))
if _, err := db.Exec(query); err != nil {
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
@ -83,7 +71,7 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
return nil
}
func printInfo(s string, args ...interface{}) {
func verboseInfo(s string, args ...interface{}) {
if verbose {
log.Printf(s, args...)
}
@ -91,10 +79,10 @@ func printInfo(s string, args ...interface{}) {
var (
matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)
matchEmptyLines = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
matchEmptyEOL = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
)
func clearStatement(s string) string {
s = matchSQLComments.ReplaceAllString(s, ``)
return matchEmptyLines.ReplaceAllString(s, ``)
return matchEmptyEOL.ReplaceAllString(s, ``)
}