Fix SQL parser tests

pull/151/head
Vojtech Vitek 2019-03-05 02:53:57 -05:00
parent 14668d05d8
commit fff58a44df
2 changed files with 115 additions and 72 deletions

View File

@ -14,15 +14,25 @@ import (
type parserState int
const (
start parserState = iota
gooseUp
gooseStatementBeginUp
gooseStatementEndUp
gooseDown
gooseStatementBeginDown
gooseStatementEndDown
start parserState = iota // 0
gooseUp // 1
gooseStatementBeginUp // 2
gooseStatementEndUp // 3
gooseDown // 4
gooseStatementBeginDown // 5
gooseStatementEndDown // 6
)
type stateMachine parserState
func (s *stateMachine) Get() parserState {
return parserState(*s)
}
func (s *stateMachine) Set(new parserState) {
verboseInfo("=> stateMachine: %v => %v", *s, new)
*s = stateMachine(new)
}
const scanBufSize = 4 * 1024 * 1024
var matchEmptyLines = regexp.MustCompile(`^\s*$`)
@ -51,67 +61,66 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool,
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
stateMachine := start
stateMachine := stateMachine(start)
useTx = true
for scanner.Scan() {
line := scanner.Text()
const goosePrefix = "-- +goose "
if strings.HasPrefix(line, goosePrefix) {
cmd := strings.TrimSpace(line[len(goosePrefix):])
verboseInfo(" %v\n", line)
if strings.HasPrefix(line, "--") {
cmd := strings.TrimSpace(strings.TrimPrefix(line, "--"))
switch cmd {
case "Up":
switch stateMachine {
case "+goose Up":
switch stateMachine.Get() {
case start:
stateMachine = gooseUp
stateMachine.Set(gooseUp)
default:
return nil, false, errors.New("failed to parse SQL migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
return nil, false, errors.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%v, see https://github.com/pressly/goose#sql-migrations", stateMachine)
}
case "Down":
switch stateMachine {
case gooseUp, gooseStatementBeginUp:
stateMachine = gooseDown
case "+goose Down":
switch stateMachine.Get() {
case gooseUp, gooseStatementEndUp:
stateMachine.Set(gooseDown)
default:
return nil, false, errors.New("failed to parse SQL migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
return nil, false, errors.Errorf("must start with '-- +goose Up' annotation, stateMachine=%v, see https://github.com/pressly/goose#sql-migrations", stateMachine)
}
case "StatementBegin":
switch stateMachine {
case gooseUp:
stateMachine = gooseStatementBeginUp
case gooseDown:
stateMachine = gooseStatementBeginDown
case "+goose StatementBegin":
switch stateMachine.Get() {
case gooseUp, gooseStatementEndUp:
stateMachine.Set(gooseStatementBeginUp)
case gooseDown, gooseStatementEndDown:
stateMachine.Set(gooseStatementBeginDown)
default:
return nil, false, errors.New("failed to parse SQL migration: '-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, see https://github.com/pressly/goose#sql-migrations")
return nil, false, errors.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%v, see https://github.com/pressly/goose#sql-migrations", stateMachine)
}
case "StatementEnd":
switch stateMachine {
case "+goose StatementEnd":
switch stateMachine.Get() {
case gooseStatementBeginUp:
stateMachine = gooseStatementEndUp
stateMachine.Set(gooseStatementEndUp)
case gooseStatementBeginDown:
stateMachine = gooseStatementEndDown
stateMachine.Set(gooseStatementEndDown)
default:
return nil, false, errors.New("failed to parse SQL migration: '-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
}
case "NO TRANSACTION":
case "+goose NO TRANSACTION":
useTx = false
default:
return nil, false, errors.Errorf("unknown annotation %q", cmd)
// Ignore comments.
verboseInfo("=> ignore comment")
}
}
// Ignore comments.
if strings.HasPrefix(line, `--`) {
continue
}
// Ignore empty lines.
if matchEmptyLines.MatchString(line) {
verboseInfo("=> ignore empty line")
continue
}
@ -125,41 +134,44 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool,
// 1) basic query with semicolon; 2) psql statement
//
// Export statement once we hit end of statement.
switch stateMachine {
case gooseUp:
switch stateMachine.Get() {
case gooseUp, gooseStatementBeginUp, gooseStatementEndUp:
if !direction /*down*/ {
buf.Reset()
break
verboseInfo("=> ignore down")
continue
}
case gooseDown, gooseStatementBeginDown, gooseStatementEndDown:
if direction /*up*/ {
buf.Reset()
verboseInfo("=> ignore up")
continue
}
default:
return nil, false, errors.Errorf("failed to parse migration: unexpected state %q on line %q, see https://github.com/pressly/goose#sql-migrations", stateMachine, line)
}
switch stateMachine.Get() {
case gooseUp:
if endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
buf.Reset()
verboseInfo("=> store simple up query")
}
case gooseDown:
if direction /*up*/ {
buf.Reset()
break
}
if endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
buf.Reset()
verboseInfo("=> store simple down query")
}
case gooseStatementEndUp:
if !direction /*down*/ {
buf.Reset()
break
}
stmts = append(stmts, buf.String())
buf.Reset()
verboseInfo("=> store up statement")
case gooseStatementEndDown:
if direction /*up*/ {
buf.Reset()
break
}
stmts = append(stmts, buf.String())
buf.Reset()
default:
return nil, false, errors.New("failed to parse migration: unexpected state %q, see https://github.com/pressly/goose#sql-migrations")
verboseInfo("=> store down statement")
}
}
if err := scanner.Err(); err != nil {
@ -167,7 +179,7 @@ func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool,
}
// EOF
switch stateMachine {
switch stateMachine.Get() {
case start:
return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
case gooseStatementBeginUp, gooseStatementBeginDown:

View File

@ -4,6 +4,8 @@ import (
"os"
"strings"
"testing"
"github.com/pkg/errors"
)
func TestSemicolons(t *testing.T) {
@ -30,26 +32,53 @@ func TestSemicolons(t *testing.T) {
}
func TestSplitStatements(t *testing.T) {
//SetVerbose(true)
type testData struct {
sql string
direction bool
count int
sql string
up int
down int
}
tests := []testData{
{sql: functxt, direction: true, count: 2},
{sql: functxt, direction: false, count: 2},
{sql: multitxt, direction: true, count: 2},
{sql: multitxt, direction: false, count: 2},
tt := []testData{
{sql: `-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
); SELECT 1;
-- comment
SELECT 2;
SELECT 3; SELECT 3;
SELECT 4;
-- +goose Down
-- comment
DROP TABLE post; SELECT 1; -- comment
`, up: 4, down: 1},
{sql: functxt, up: 2, down: 2},
}
for _, test := range tests {
stmts, _, err := parseSQLMigration(strings.NewReader(test.sql), test.direction)
for i, test := range tt {
// up
stmts, _, err := parseSQLMigration(strings.NewReader(test.sql), true)
if err != nil {
t.Error(err)
t.Error(errors.Wrapf(err, "tt[%v] unexpected error", i))
}
if len(stmts) != test.count {
t.Errorf("incorrect number of stmts. got %v, want %v", len(stmts), test.count)
if len(stmts) != test.up {
t.Errorf("tt[%v] incorrect number of up stmts. got %v (%+v), want %v", i, len(stmts), stmts, test.up)
}
// down
stmts, _, err = parseSQLMigration(strings.NewReader(test.sql), false)
if err != nil {
t.Error(errors.Wrapf(err, "tt[%v] unexpected error", i))
}
if len(stmts) != test.down {
t.Errorf("tt[%v] incorrect number of down stmts. got %v (%+v), want %v", i, len(stmts), stmts, test.down)
}
}
}
@ -88,6 +117,8 @@ func TestParsingErrors(t *testing.T) {
unfinishedSQL,
noUpDownAnnotations,
emptySQL,
multiUpDown,
downFirst,
}
for _, sql := range tt {
_, _, err := parseSQLMigration(strings.NewReader(sql), true)
@ -132,8 +163,7 @@ drop function histories_partition_creation(DATE, DATE);
drop TABLE histories;
`
// test multiple up/down transitions in a single script
var multitxt = `-- +goose Up
var multiUpDown = `-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
@ -152,8 +182,9 @@ CREATE TABLE fancier_post (
created_on timestamp without time zone,
PRIMARY KEY(id)
);
`
-- +goose Down
var downFirst = `-- +goose Down
DROP TABLE fancier_post;
`