// sqlCmdPrefix is the leading text of a goose annotation line inside a SQL
// migration script, e.g. "-- +goose Up".
const sqlCmdPrefix = "-- +goose "

// endsWithSemicolon reports whether the last whitespace-separated word on
// line, ignoring a trailing "--" comment, ends with a semicolon.
//
// Words are examined left to right and the scan stops at the first word
// that begins with "--", so a semicolon that appears only inside the
// comment does not count. A word that merely contains "--" somewhere after
// its first character is not treated as the start of a comment.
// An empty or all-whitespace line yields false.
func endsWithSemicolon(line string) bool {
	prev := ""

	// strings.Fields splits on the same Unicode whitespace as
	// bufio.ScanWords but has no maximum token size. A bufio.Scanner stops
	// with an error on a word longer than 64 KiB; with that error ignored,
	// prev would be left at an earlier word and the result would be wrong.
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			break
		}
		prev = word
	}

	return strings.HasSuffix(prev, ";")
}

// Split the given sql script into individual statements.
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
+func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { + + var buf bytes.Buffer + scanner := bufio.NewScanner(r) + + // track the count of each section + // so we can diagnose scripts with no annotations + upSections := 0 + downSections := 0 + + statementEnded := false + ignoreSemicolons := false + directionIsActive := false + + for scanner.Scan() { + + line := scanner.Text() + + // handle any goose-specific commands + if strings.HasPrefix(line, sqlCmdPrefix) { + cmd := strings.TrimSpace(line[len(sqlCmdPrefix):]) + switch cmd { + case "Up": + directionIsActive = (direction == true) + upSections++ + break + + case "Down": + directionIsActive = (direction == false) + downSections++ + break + + case "StatementBegin": + if directionIsActive { + ignoreSemicolons = true + } + break + + case "StatementEnd": + if directionIsActive { + statementEnded = (ignoreSemicolons == true) + ignoreSemicolons = false + } + break + } + } + + if !directionIsActive { + continue + } + + if _, err := buf.WriteString(line + "\n"); err != nil { + log.Fatalf("io err", err) + } + + if !ignoreSemicolons && (statementEnded || endsWithSemicolon(line)) { + statementEnded = false + stmts = append(stmts, buf.String()) + buf.Reset() + } + } + + if err := scanner.Err(); err != nil { + log.Fatalf("scanning migration:", err) + } + + // diagnose likely migration script errors + if ignoreSemicolons { + log.Println("WARNING: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'") + } + + if upSections == 0 && downSections == 0 { + log.Fatalf(`ERROR: no Up/Down annotations found, so no statements were executed. + See https://bitbucket.org/liamstask/goose/overview for details.`) + } + + return +} + // Run a migration specified in raw SQL. 
// // Sections of the script can be annotated with a special comment, @@ -23,40 +133,14 @@ func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction log.Fatal("db.Begin:", err) } - f, err := ioutil.ReadFile(script) + f, err := os.Open(script) if err != nil { log.Fatal(err) } - // track the count of each section - // so we can diagnose scripts with no annotations - upSections := 0 - downSections := 0 - - // ensure we don't apply a query until we're sure it's going - // in the direction we're interested in - directionIsActive := false - // find each statement, checking annotations for up/down direction // and execute each of them in the current transaction - stmts := strings.Split(string(f), ";") - - for _, query := range stmts { - - query = strings.TrimSpace(query) - - if strings.HasPrefix(query, "-- +goose Up") { - directionIsActive = direction == true - upSections++ - } else if strings.HasPrefix(query, "-- +goose Down") { - directionIsActive = direction == false - downSections++ - } - - if !directionIsActive || query == "" { - continue - } - + for _, query := range splitSQLStatements(f, direction) { if _, err = txn.Exec(query); err != nil { txn.Rollback() log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(script), err) @@ -64,13 +148,6 @@ func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction } } - if upSections == 0 && downSections == 0 { - txn.Rollback() - log.Fatalf(`ERROR: no Up/Down annotations found in %s, so no statements were executed. - See https://bitbucket.org/liamstask/goose/overview for details.`, - filepath.Base(script)) - } - if err = finalizeMigration(conf, txn, direction, v); err != nil { log.Fatalf("error finalizing migration %s, quitting. 
(%v)", filepath.Base(script), err) } diff --git a/migration_sql_test.go b/migration_sql_test.go new file mode 100644 index 0000000..48c34cb --- /dev/null +++ b/migration_sql_test.go @@ -0,0 +1,147 @@ +package main + +import ( + "strings" + "testing" +) + +func TestSemicolons(t *testing.T) { + + type testData struct { + line string + result bool + } + + tests := []testData{ + { + line: "END;", + result: true, + }, + { + line: "END; -- comment", + result: true, + }, + { + line: "END ; -- comment", + result: true, + }, + { + line: "END -- comment", + result: false, + }, + { + line: "END -- comment ;", + result: false, + }, + { + line: "END \" ; \" -- comment", + result: false, + }, + } + + for _, test := range tests { + r := endsWithSemicolon(test.line) + if r != test.result { + t.Errorf("incorrect semicolon. got %v, want %v", r, test.result) + } + } +} + +func TestSplitStatements(t *testing.T) { + + type testData struct { + sql string + direction bool + count int + } + + tests := []testData{ + { + sql: functxt, + direction: true, + count: 2, + }, + { + sql: functxt, + direction: false, + count: 2, + }, + { + sql: multitxt, + direction: true, + count: 2, + }, + { + sql: multitxt, + direction: false, + count: 2, + }, + } + + for _, test := range tests { + stmts := splitSQLStatements(strings.NewReader(test.sql), test.direction) + if len(stmts) != test.count { + t.Errorf("incorrect number of stmts. 
got %v, want %v", len(stmts), test.count) + } + } +} + +var functxt = `-- +goose Up +CREATE TABLE IF NOT EXISTS histories ( + id BIGSERIAL PRIMARY KEY, + current_value varchar(2000) NOT NULL, + created_at timestamp with time zone NOT NULL +); + +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE ) +returns void AS $$ +DECLARE + create_query text; +BEGIN + FOR create_query IN SELECT + 'CREATE TABLE IF NOT EXISTS histories_' + || TO_CHAR( d, 'YYYY_MM' ) + || ' ( CHECK( created_at >= timestamp ''' + || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' ) + || ''' AND created_at < timestamp ''' + || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' ) + || ''' ) ) inherits ( histories );' + FROM generate_series( $1, $2, '1 month' ) AS d + LOOP + EXECUTE create_query; + END LOOP; -- LOOP END +END; -- FUNCTION END +$$ +language plpgsql; +-- +goose StatementEnd + +-- +goose Down +drop function histories_partition_creation(DATE, DATE); +drop TABLE histories; +` + +// test multiple up/down transitions in a single script +var multitxt = `-- +goose Up +CREATE TABLE post ( + id int NOT NULL, + title text, + body text, + PRIMARY KEY(id) +); + +-- +goose Down +DROP TABLE post; + +-- +goose Up +CREATE TABLE fancier_post ( + id int NOT NULL, + title text, + body text, + created_on timestamp without time zone, + PRIMARY KEY(id) +); + +-- +goose Down +DROP TABLE fancier_post; +`