sql migrations: add new annotation to accommodate SQL statements that may have semicolons within them, such as functions. some basic tests as well. fixes #4

pull/2/head
Liam Staskawicz 2013-09-27 12:17:17 -07:00
parent 5dfd15ece2
commit cb801ded9c
2 changed files with 260 additions and 36 deletions

View File

@ -1,13 +1,123 @@
package main package main
import ( import (
"bufio"
"bytes"
"database/sql" "database/sql"
"io/ioutil" "io"
"log" "log"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
) )
// sqlCmdPrefix is the prefix of a script line that carries a goose
// annotation. The text following the prefix names the command, e.g.
// Up, Down, StatementBegin or StatementEnd.
const sqlCmdPrefix = "-- +goose "
// endsWithSemicolon reports whether the last whitespace-separated word of
// line that precedes any "--" comment ends with a semicolon.
//
// A word that begins with "--" is treated as the start of a SQL line comment:
// it and everything after it are ignored. A semicolon that appears only inside
// the comment, or that is followed by further non-comment words, does not
// count.
func endsWithSemicolon(line string) bool {
	last := ""
	// strings.Fields splits on the same Unicode white space as
	// bufio.ScanWords, but it has no token-size limit and cannot fail, so a
	// very long word can no longer cause the rest of the line to be skipped
	// silently.
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			break
		}
		last = word
	}
	return strings.HasSuffix(last, ";")
}
// splitSQLStatements splits the sql script read from r into the individual
// statements that belong to the requested direction (true selects the
// '-- +goose Up' sections, false the '-- +goose Down' sections).
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
//
// Every line seen while the requested direction is active, including the
// annotation lines themselves, is accumulated into the current statement.
// The process is terminated via log.Fatalf if reading r fails or if the
// script contains no Up/Down annotation at all.
func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
	var buf bytes.Buffer
	scanner := bufio.NewScanner(r)

	// track the count of each section
	// so we can diagnose scripts with no annotations
	upSections := 0
	downSections := 0

	statementEnded := false    // a StatementEnd closed an explicit block on this line
	ignoreSemicolons := false  // inside a StatementBegin/StatementEnd block
	directionIsActive := false // current section matches the requested direction

	for scanner.Scan() {
		line := scanner.Text()

		// handle any goose-specific commands
		if strings.HasPrefix(line, sqlCmdPrefix) {
			switch strings.TrimSpace(line[len(sqlCmdPrefix):]) {
			case "Up":
				directionIsActive = direction
				upSections++
			case "Down":
				directionIsActive = !direction
				downSections++
			case "StatementBegin":
				if directionIsActive {
					ignoreSemicolons = true
				}
			case "StatementEnd":
				if directionIsActive {
					// only flush if there was a matching StatementBegin
					statementEnded = ignoreSemicolons
					ignoreSemicolons = false
				}
			}
		}

		if !directionIsActive {
			continue
		}

		if _, err := buf.WriteString(line + "\n"); err != nil {
			log.Fatalf("io err: %v", err)
		}

		if !ignoreSemicolons && (statementEnded || endsWithSemicolon(line)) {
			statementEnded = false
			stmts = append(stmts, buf.String())
			buf.Reset()
		}
	}

	if err := scanner.Err(); err != nil {
		log.Fatalf("scanning migration: %v", err)
	}

	// diagnose likely migration script errors
	if ignoreSemicolons {
		log.Println("WARNING: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'")
	} else {
		// Anything still buffered never reached a terminating semicolon and
		// is not returned; warn unless it is only blank lines and comments.
		for _, rest := range strings.Split(buf.String(), "\n") {
			rest = strings.TrimSpace(rest)
			if rest != "" && !strings.HasPrefix(rest, "--") {
				log.Println("WARNING: ignoring trailing SQL with no terminating semicolon")
				break
			}
		}
	}

	if upSections == 0 && downSections == 0 {
		log.Fatalf(`ERROR: no Up/Down annotations found, so no statements were executed.
See https://bitbucket.org/liamstask/goose/overview for details.`)
	}

	return stmts
}
// Run a migration specified in raw SQL. // Run a migration specified in raw SQL.
// //
// Sections of the script can be annotated with a special comment, // Sections of the script can be annotated with a special comment,
@ -23,40 +133,14 @@ func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction
log.Fatal("db.Begin:", err) log.Fatal("db.Begin:", err)
} }
f, err := ioutil.ReadFile(script) f, err := os.Open(script)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// track the count of each section
// so we can diagnose scripts with no annotations
upSections := 0
downSections := 0
// ensure we don't apply a query until we're sure it's going
// in the direction we're interested in
directionIsActive := false
// find each statement, checking annotations for up/down direction // find each statement, checking annotations for up/down direction
// and execute each of them in the current transaction // and execute each of them in the current transaction
stmts := strings.Split(string(f), ";") for _, query := range splitSQLStatements(f, direction) {
for _, query := range stmts {
query = strings.TrimSpace(query)
if strings.HasPrefix(query, "-- +goose Up") {
directionIsActive = direction == true
upSections++
} else if strings.HasPrefix(query, "-- +goose Down") {
directionIsActive = direction == false
downSections++
}
if !directionIsActive || query == "" {
continue
}
if _, err = txn.Exec(query); err != nil { if _, err = txn.Exec(query); err != nil {
txn.Rollback() txn.Rollback()
log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(script), err) log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(script), err)
@ -64,13 +148,6 @@ func runSQLMigration(conf *DBConf, db *sql.DB, script string, v int64, direction
} }
} }
if upSections == 0 && downSections == 0 {
txn.Rollback()
log.Fatalf(`ERROR: no Up/Down annotations found in %s, so no statements were executed.
See https://bitbucket.org/liamstask/goose/overview for details.`,
filepath.Base(script))
}
if err = finalizeMigration(conf, txn, direction, v); err != nil { if err = finalizeMigration(conf, txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err) log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(script), err)
} }

147
migration_sql_test.go Normal file
View File

@ -0,0 +1,147 @@
package main
import (
"strings"
"testing"
)
// TestSemicolons checks endsWithSemicolon against lines with and without
// trailing "--" comments, including semicolons that appear only inside a
// comment or before further non-comment words.
func TestSemicolons(t *testing.T) {
	tests := []struct {
		line   string
		result bool
	}{
		{line: "END;", result: true},
		{line: "END; -- comment", result: true},
		{line: "END ; -- comment", result: true},
		{line: "END -- comment", result: false},
		{line: "END -- comment ;", result: false},
		{line: "END \" ; \" -- comment", result: false},
	}

	for _, test := range tests {
		// Include the input in the message so a failure identifies its case.
		if r := endsWithSemicolon(test.line); r != test.result {
			t.Errorf("endsWithSemicolon(%q): incorrect semicolon. got %v, want %v", test.line, r, test.result)
		}
	}
}
// TestSplitStatements checks how many statements splitSQLStatements returns
// for each fixture script in both the up (true) and down (false) direction.
func TestSplitStatements(t *testing.T) {
	tests := []struct {
		name      string // identifies the case in failure output
		sql       string
		direction bool
		count     int
	}{
		{name: "functxt up", sql: functxt, direction: true, count: 2},
		{name: "functxt down", sql: functxt, direction: false, count: 2},
		{name: "multitxt up", sql: multitxt, direction: true, count: 2},
		{name: "multitxt down", sql: multitxt, direction: false, count: 2},
	}

	for _, test := range tests {
		stmts := splitSQLStatements(strings.NewReader(test.sql), test.direction)
		if len(stmts) != test.count {
			t.Errorf("%s: incorrect number of stmts. got %v, want %v", test.name, len(stmts), test.count)
		}
	}
}
// functxt is a migration fixture. Its Up section holds a plain CREATE TABLE
// followed by a pl/pgsql function (with semicolons in its body) wrapped in
// StatementBegin/StatementEnd annotations; its Down section holds two drop
// statements.
var functxt = `-- +goose Up
CREATE TABLE IF NOT EXISTS histories (
id BIGSERIAL PRIMARY KEY,
current_value varchar(2000) NOT NULL,
created_at timestamp with time zone NOT NULL
);
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
create_query text;
BEGIN
FOR create_query IN SELECT
'CREATE TABLE IF NOT EXISTS histories_'
|| TO_CHAR( d, 'YYYY_MM' )
|| ' ( CHECK( created_at >= timestamp '''
|| TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
|| ''' AND created_at < timestamp '''
|| TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
|| ''' ) ) inherits ( histories );'
FROM generate_series( $1, $2, '1 month' ) AS d
LOOP
EXECUTE create_query;
END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;
-- +goose StatementEnd
-- +goose Down
drop function histories_partition_creation(DATE, DATE);
drop TABLE histories;
`
// multitxt is a migration fixture that tests multiple up/down transitions in
// a single script: it alternates Up, Down, Up, Down, with one statement per
// section, so each direction contains two statements in total.
var multitxt = `-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE post;
-- +goose Up
CREATE TABLE fancier_post (
id int NOT NULL,
title text,
body text,
created_on timestamp without time zone,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE fancier_post;
`