Kick off new SQL parser

pull/151/head
Vojtech Vitek 2019-03-05 01:11:22 -05:00
parent f81c971ff2
commit 456f34d42d
3 changed files with 235 additions and 179 deletions

View File

@ -1,153 +1,13 @@
package goose
import (
"bufio"
"bytes"
"database/sql"
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"
"github.com/pkg/errors"
)
const sqlCmdPrefix = "-- +goose "
const scanBufSize = 4 * 1024 * 1024
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, scanBufSize)
},
}
// Checks the line to see if the line has a statement-ending semicolon
// or if the line contains a double-dash comment.
func endsWithSemicolon(line string) bool {
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
prev := ""
scanner := bufio.NewScanner(strings.NewReader(line))
scanner.Buffer(scanBuf, scanBufSize)
scanner.Split(bufio.ScanWords)
for scanner.Scan() {
word := scanner.Text()
if strings.HasPrefix(word, "--") {
break
}
prev = word
}
return strings.HasSuffix(prev, ";")
}
// Split the given sql script into individual statements.
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func getSQLStatements(r io.Reader, direction bool) ([]string, bool, error) {
var buf bytes.Buffer
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
// track the count of each section
// so we can diagnose scripts with no annotations
upSections := 0
downSections := 0
statementEnded := false
ignoreSemicolons := false
directionIsActive := false
tx := true
stmts := []string{}
for scanner.Scan() {
line := scanner.Text()
// handle any goose-specific commands
if strings.HasPrefix(line, sqlCmdPrefix) {
cmd := strings.TrimSpace(line[len(sqlCmdPrefix):])
switch cmd {
case "Up":
directionIsActive = (direction == true)
upSections++
break
case "Down":
directionIsActive = (direction == false)
downSections++
break
case "StatementBegin":
if directionIsActive {
ignoreSemicolons = true
}
break
case "StatementEnd":
if directionIsActive {
statementEnded = (ignoreSemicolons == true)
ignoreSemicolons = false
}
break
case "NO TRANSACTION":
tx = false
break
}
}
if !directionIsActive {
continue
}
if _, err := buf.WriteString(line + "\n"); err != nil {
return nil, false, fmt.Errorf("io err: %v", err)
}
// Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement
// Lines that end with semicolon that are in a statement block
// do not conclude statement.
if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded {
statementEnded = false
stmts = append(stmts, buf.String())
buf.Reset()
}
}
if err := scanner.Err(); err != nil {
return nil, false, fmt.Errorf("scanning migration: %v", err)
}
// diagnose likely migration script errors
if ignoreSemicolons {
return nil, false, fmt.Errorf("parsing migration: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'")
}
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, fmt.Errorf("parsing migration: unexpected unfinished SQL query: %s. potential missing semicolon", bufferRemaining)
}
if upSections == 0 && downSections == 0 {
return nil, false, fmt.Errorf("parsing migration: no Up/Down annotations found, so no statements were executed. See https://bitbucket.org/liamstask/goose/overview for details")
}
return stmts, tx, nil
}
// Run a migration specified in raw SQL.
//
// Sections of the script can be annotated with a special comment,
@ -163,9 +23,9 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
}
defer f.Close()
statements, useTx, err := getSQLStatements(f, direction)
statements, useTx, err := parseSQLMigrationFile(f, direction)
if err != nil {
return err
return errors.Wrap(err, "failed to parse SQL migration file")
}
if useTx {
@ -231,7 +91,7 @@ func printInfo(s string, args ...interface{}) {
var (
matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)
matchEmptyLines = regexp.MustCompile(`(?m)^$[\r\n]*`)
matchEmptyLines = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
)
func clearStatement(s string) string {

View File

@ -83,48 +83,45 @@ func TestUseTransactions(t *testing.T) {
}
func TestParsingErrors(t *testing.T) {
type testData struct {
sql string
error bool
tt := []string{
statementBeginNoStatementEnd,
unfinishedSQL,
noUpDownAnnotations,
emptySQL,
}
tests := []testData{
{sql: statementBeginNoStatementEnd, error: true},
{sql: unfinishedSQL, error: true},
{sql: noUpDownAnnotations, error: true},
}
for _, test := range tests {
_, _, err := getSQLStatements(strings.NewReader(test.sql), true)
for _, sql := range tt {
_, _, err := getSQLStatements(strings.NewReader(sql), true)
if err == nil {
t.Errorf("Failed transaction check. got %v, want %v", err, test.error)
t.Errorf("expected error on %q", sql)
}
}
}
var functxt = `-- +goose Up
CREATE TABLE IF NOT EXISTS histories (
id BIGSERIAL PRIMARY KEY,
current_value varchar(2000) NOT NULL,
created_at timestamp with time zone NOT NULL
id BIGSERIAL PRIMARY KEY,
current_value varchar(2000) NOT NULL,
created_at timestamp with time zone NOT NULL
);
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
create_query text;
create_query text;
BEGIN
FOR create_query IN SELECT
'CREATE TABLE IF NOT EXISTS histories_'
|| TO_CHAR( d, 'YYYY_MM' )
|| ' ( CHECK( created_at >= timestamp '''
|| TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
|| ''' AND created_at < timestamp '''
|| TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
|| ''' ) ) inherits ( histories );'
FROM generate_series( $1, $2, '1 month' ) AS d
LOOP
EXECUTE create_query;
END LOOP; -- LOOP END
FOR create_query IN SELECT
'CREATE TABLE IF NOT EXISTS histories_'
|| TO_CHAR( d, 'YYYY_MM' )
|| ' ( CHECK( created_at >= timestamp '''
|| TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
|| ''' AND created_at < timestamp '''
|| TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
|| ''' ) ) inherits ( histories );'
FROM generate_series( $1, $2, '1 month' ) AS d
LOOP
EXECUTE create_query;
END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;
@ -138,10 +135,10 @@ drop TABLE histories;
// test multiple up/down transitions in a single script
var multitxt = `-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
);
-- +goose Down
@ -149,11 +146,11 @@ DROP TABLE post;
-- +goose Up
CREATE TABLE fancier_post (
id int NOT NULL,
title text,
body text,
created_on timestamp without time zone,
PRIMARY KEY(id)
id int NOT NULL,
title text,
body text,
created_on timestamp without time zone,
PRIMARY KEY(id)
);
-- +goose Down
@ -200,6 +197,10 @@ ALTER TABLE post
-- +goose Down
`
var emptySQL = `-- +goose Up
-- This is just a comment`
var noUpDownAnnotations = `
CREATE TABLE post (
id int NOT NULL,

195
parser.go Normal file
View File

@ -0,0 +1,195 @@
package goose
import (
"bufio"
"bytes"
"io"
"regexp"
"strings"
"sync"
"github.com/pkg/errors"
)
type parserState int
const (
start parserState = iota
gooseUp
gooseStatementBeginUp
gooseStatementEndUp
gooseDown
gooseStatementBeginDown
gooseStatementEndDown
)
const scanBufSize = 4 * 1024 * 1024
var matchEmptyLines = regexp.MustCompile(`^\s*$`)
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, scanBufSize)
},
}
// Split given SQL script into individual statements and return
// SQL statements for given direction (up=true, down=false).
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func parseSQLMigrationFile(r io.Reader, direction bool) (stmts []string, useTx bool, err error) {
var buf bytes.Buffer
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
stateMachine := start
useTx = true
for scanner.Scan() {
line := scanner.Text()
const goosePrefix = "-- +goose "
if strings.HasPrefix(line, goosePrefix) {
cmd := strings.TrimSpace(line[len(goosePrefix):])
switch cmd {
case "Up":
switch stateMachine {
case start:
stateMachine = gooseUp
default:
return nil, false, errors.New("failed to parse SQL migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
}
case "Down":
switch stateMachine {
case gooseUp, gooseStatementBeginUp:
stateMachine = gooseDown
default:
return nil, false, errors.New("failed to parse SQL migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
}
case "StatementBegin":
switch stateMachine {
case gooseUp:
stateMachine = gooseStatementBeginUp
case gooseDown:
stateMachine = gooseStatementBeginDown
default:
return nil, false, errors.New("failed to parse SQL migration: '-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, see https://github.com/pressly/goose#sql-migrations")
}
case "StatementEnd":
switch stateMachine {
case gooseStatementBeginUp:
stateMachine = gooseStatementEndUp
case gooseStatementBeginDown:
stateMachine = gooseStatementEndDown
default:
return nil, false, errors.New("failed to parse SQL migration: '-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
}
case "NO TRANSACTION":
useTx = false
default:
return nil, false, errors.Errorf("unknown annotation %q", cmd)
}
}
// Ignore comments.
if strings.HasPrefix(line, `--`) {
continue
}
// Ignore empty lines.
if matchEmptyLines.MatchString(line) {
continue
}
// Write SQL line to a buffer.
if _, err := buf.WriteString(line + "\n"); err != nil {
return nil, false, errors.Wrap(err, "failed to write to buf")
}
// Read SQL body one by line, if we're in the right direction.
//
// 1) basic query with semicolon; 2) psql statement
//
// Export statement once we hit end of statement.
switch stateMachine {
case gooseUp:
if !endsWithSemicolon(line) {
return nil, false, errors.Errorf("failed to parse Up SQL migration: %q: simple query must be terminated by semicolon;", line)
}
if direction { // up
stmts = append(stmts, buf.String())
}
case gooseDown:
if !endsWithSemicolon(line) {
return nil, false, errors.Errorf("failed to parse Down SQL migration: %q: simple query must be terminated by semicolon;", line)
}
if !direction { // down
stmts = append(stmts, buf.String())
}
case gooseStatementEndUp:
if direction /*up*/ && endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
}
case gooseStatementEndDown:
if !direction /*down*/ && endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
}
default:
return nil, false, errors.New("failed to parse migration: unexpected state %q, see https://github.com/pressly/goose#sql-migrations")
}
buf.Reset()
}
if err := scanner.Err(); err != nil {
return nil, false, errors.Wrap(err, "failed to scan migration")
}
// EOF
switch stateMachine {
case start:
return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
case gooseStatementBeginUp, gooseStatementBeginDown:
return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
}
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, errors.Errorf("failed to parse migration: unexpected unfinished SQL query: %q: missing semicolon?", bufferRemaining)
}
return stmts, useTx, nil
}
// Checks the line to see if the line has a statement-ending semicolon
// or if the line contains a double-dash comment.
func endsWithSemicolon(line string) bool {
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
prev := ""
scanner := bufio.NewScanner(strings.NewReader(line))
scanner.Buffer(scanBuf, scanBufSize)
scanner.Split(bufio.ScanWords)
for scanner.Scan() {
word := scanner.Text()
if strings.HasPrefix(word, "--") {
break
}
prev = word
}
return strings.HasSuffix(prev, ";")
}