goose/internal/sqlparser/parser.go

395 lines
11 KiB
Go

package sqlparser
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"os"
"strings"
"sync"
"github.com/mfridman/interpolate"
)
// Direction identifies which section of a SQL migration is being selected:
// the "up" section or the "down" section.
type Direction string

// Supported migration directions.
const (
	DirectionUp   Direction = "up"
	DirectionDown Direction = "down"
)

// FromBool converts a boolean flag into a Direction: true yields DirectionUp
// and false yields DirectionDown.
func FromBool(b bool) Direction {
	if !b {
		return DirectionDown
	}
	return DirectionUp
}

// String returns the underlying string value of the direction.
func (d Direction) String() string {
	return string(d)
}

// ToBool reports whether the direction is DirectionUp. Every other value,
// including ones that are neither "up" nor "down", yields false.
func (d Direction) ToBool() bool {
	switch d {
	case DirectionUp:
		return true
	default:
		return false
	}
}
// parserState enumerates the positions the migration parser can be in while
// it walks a SQL file line by line.
type parserState int

// Parser states. The numeric values are printed with %d in debug output and
// in error messages, so their order is kept as is.
const (
	start                   parserState = iota // 0: no '+goose Up' annotation seen yet
	gooseUp                                    // 1: inside the Up section, splitting on semicolons
	gooseStatementBeginUp                      // 2: inside a StatementBegin block of the Up section
	gooseStatementEndUp                        // 3: StatementEnd seen in the Up section
	gooseDown                                  // 4: inside the Down section, splitting on semicolons
	gooseStatementBeginDown                    // 5: inside a StatementBegin block of the Down section
	gooseStatementEndDown                      // 6: StatementEnd seen in the Down section
)
// stateMachine holds the parser's current state and, when verbose is set,
// logs every transition and parser note through the standard logger.
type stateMachine struct {
	state   parserState
	verbose bool
}

// newStateMachine returns a state machine positioned at begin. verbose
// controls whether print (and therefore set) writes to the log.
func newStateMachine(begin parserState, verbose bool) *stateMachine {
	return &stateMachine{
		state:   begin,
		verbose: verbose,
	}
}

// get returns the current state.
func (s *stateMachine) get() parserState {
	return s.state
}

// set records the transition via print and then moves the machine to next.
// The parameter is deliberately not called "new" so that it does not shadow
// the predeclared function of that name.
func (s *stateMachine) set(next parserState) {
	s.print("set %d => %d", s.state, next)
	s.state = next
}

// Escape sequences wrapped around debug output; grayColor starts the colored
// span and resetColor ends it.
const (
	grayColor  = "\033[90m"
	resetColor = "\033[00m"
)

// print logs a printf-style message prefixed with "StateMachine: " and
// wrapped in the color escape sequences. It does nothing unless the machine
// is verbose; the check comes first so that no message is assembled when
// logging is disabled.
func (s *stateMachine) print(msg string, args ...interface{}) {
	if !s.verbose {
		return
	}
	log.Printf(grayColor+"StateMachine: "+msg+resetColor, args...)
}
// scanBufSize is the size in bytes (4 MiB) of each pooled scan buffer. The
// same value is handed to bufio.Scanner.Buffer as the maximum token size.
const scanBufSize = 4 << 20

// bufferPool recycles scan buffers between parser invocations so that a fresh
// scanBufSize-byte slice does not have to be allocated for every call. The
// pool stores *[]byte rather than []byte, so Get and Put move a pointer
// instead of a slice header.
var bufferPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, scanBufSize)
		return &b
	},
}
// ParseSQLMigration reads a SQL migration script from r, splits it into
// individual statements and returns the statements that belong to the given
// direction, in the order they appear in the script.
//
// The base case is to simply split on semicolons, as these naturally
// terminate a statement: a statement ends on the first line for which
// endsWithSemicolon reports true.
//
// However, more complex cases like pl/pgsql can have semicolons within a
// statement. For these cases the script can use the explicit annotations
// '-- +goose StatementBegin' and '-- +goose StatementEnd' to tell the parser
// to ignore semicolons; everything in between becomes one statement.
//
// The remaining annotations are '-- +goose Up' and '-- +goose Down', which
// select the section, '-- +goose NO TRANSACTION', which sets useTx to false
// (it is true otherwise), and '-- +goose ENVSUB ON' / '-- +goose ENVSUB OFF',
// which switch interpolate.Interpolate on and off for the lines that follow.
//
// When debug is true every input line is logged and the state machine logs
// its transitions. On every error path stmts is nil and useTx is false.
func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) {
// Borrow a pooled buffer for the line scanner and return it on exit.
scanBufPtr := bufferPool.Get().(*[]byte)
scanBuf := *scanBufPtr
defer bufferPool.Put(scanBufPtr)
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
stateMachine := newStateMachine(start, debug)
// useTx stays true unless a NO TRANSACTION annotation is encountered.
useTx = true
useEnvsub := false
// buf accumulates the lines of the statement currently being read.
var buf bytes.Buffer
for scanner.Scan() {
line := scanner.Text()
if debug {
log.Println(line)
}
// Blank lines before the first annotation are skipped.
if stateMachine.get() == start && strings.TrimSpace(line) == "" {
continue
}
// Check for annotations.
// All annotations must be in format: "-- +goose [annotation]"
// Any comment line that contains "+goose" anywhere is treated as an
// annotation and must parse as one.
if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
var cmd annotation
cmd, err = extractAnnotation(line)
if err != nil {
return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
}
switch cmd {
case annotationUp:
// Up is only valid as the very first annotation.
switch stateMachine.get() {
case start:
stateMachine.set(gooseUp)
default:
return nil, false, fmt.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
}
continue
case annotationDown:
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
// If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
// previous up annotation. This is an error, because we expect the SQL query to be terminated by a semicolon
// and the buffer to have been reset.
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
}
stateMachine.set(gooseDown)
default:
return nil, false, fmt.Errorf("must start with '-- +goose Up' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
}
continue
case annotationStatementBegin:
// A block may open from the plain section state or right after a
// previous block's StatementEnd state.
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
stateMachine.set(gooseStatementBeginUp)
case gooseDown, gooseStatementEndDown:
stateMachine.set(gooseStatementBeginDown)
default:
return nil, false, fmt.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
}
continue
case annotationStatementEnd:
switch stateMachine.get() {
case gooseStatementBeginUp:
stateMachine.set(gooseStatementEndUp)
case gooseStatementBeginDown:
stateMachine.set(gooseStatementEndDown)
default:
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
}
// No continue here: the line falls through so that the final switch
// of this iteration can store the buffered statement.
case annotationNoTransaction:
useTx = false
continue
case annotationEnvsubOn:
useEnvsub = true
continue
case annotationEnvsubOff:
useEnvsub = false
continue
default:
return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
}
}
// Once we've started parsing a statement the buffer is no longer empty,
// we keep all comments up until the end of the statement (the buffer will be reset).
// All other comments in the file are ignored.
//
// NOTE(review): a '-- +goose StatementEnd' line is itself a "--" comment and
// reaches this check. If the buffer is still empty at that point (for the
// selected direction: a StatementBegin/StatementEnd block without a body),
// the line is skipped here and the state stays at gooseStatementEnd*; the
// next SQL line then appears to be dropped and an empty statement stored in
// its place. Confirm whether that is intended.
if buf.Len() == 0 {
// This check ensures leading comments and empty lines prior to a statement are ignored.
// Only a completely empty line matches; a line made of white space alone does not.
if strings.HasPrefix(strings.TrimSpace(line), "--") || line == "" {
stateMachine.print("ignore comment")
continue
}
}
switch stateMachine.get() {
case gooseStatementEndDown, gooseStatementEndUp:
// Do not include the "+goose StatementEnd" annotation in the final statement.
default:
if useEnvsub {
// Substitute variables in this line; envWrapper resolves them with os.LookupEnv.
expanded, err := interpolate.Interpolate(&envWrapper{}, line)
if err != nil {
return nil, false, fmt.Errorf("variable substitution failed: %w:\n%s", err, line)
}
line = expanded
}
// Write SQL line to a buffer.
if _, err := buf.WriteString(line + "\n"); err != nil {
return nil, false, fmt.Errorf("failed to write to buf: %w", err)
}
}
// Read the SQL body one line at a time, if we're in the right direction.
//
// 1) basic query with semicolon; 2) psql statement
//
// Export statement once we hit end of statement.
//
// Lines that belong to the other direction are discarded right away by
// resetting the buffer.
switch stateMachine.get() {
case gooseUp, gooseStatementBeginUp, gooseStatementEndUp:
if direction == DirectionDown {
buf.Reset()
stateMachine.print("ignore down")
continue
}
case gooseDown, gooseStatementBeginDown, gooseStatementEndDown:
if direction == DirectionUp {
buf.Reset()
stateMachine.print("ignore up")
continue
}
default:
// Still in the start state: SQL appeared before any '+goose Up' annotation.
return nil, false, fmt.Errorf("failed to parse migration: unexpected state %d on line %q, see https://github.com/pressly/goose#sql-migrations", stateMachine.state, line)
}
// Decide whether the buffered lines form a complete statement. The
// gooseStatementBegin* states have no case: their lines keep accumulating
// until the StatementEnd annotation arrives.
switch stateMachine.get() {
case gooseUp:
if endsWithSemicolon(line) {
stmts = append(stmts, cleanupStatement(buf.String()))
buf.Reset()
stateMachine.print("store simple Up query")
}
case gooseDown:
if endsWithSemicolon(line) {
stmts = append(stmts, cleanupStatement(buf.String()))
buf.Reset()
stateMachine.print("store simple Down query")
}
case gooseStatementEndUp:
// End of an explicit block: store it and go back to semicolon splitting.
stmts = append(stmts, cleanupStatement(buf.String()))
buf.Reset()
stateMachine.print("store Up statement")
stateMachine.set(gooseUp)
case gooseStatementEndDown:
stmts = append(stmts, cleanupStatement(buf.String()))
buf.Reset()
stateMachine.print("store Down statement")
stateMachine.set(gooseDown)
}
}
// Scan returns false both at EOF and on a read error; tell them apart here.
if err := scanner.Err(); err != nil {
return nil, false, fmt.Errorf("failed to scan migration: %w", err)
}
// EOF
// The script must have contained an Up annotation and must not end inside
// an open StatementBegin block.
switch stateMachine.get() {
case start:
return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
case gooseStatementBeginUp, gooseStatementBeginDown:
return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
}
// Anything left in the buffer is SQL that never reached a terminating semicolon.
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
}
return stmts, useTx, nil
}
// annotation is the command part of a "-- +goose [annotation]" comment line.
type annotation string

// Annotations understood by the parser, in their canonical spelling.
const (
	annotationUp             annotation = "Up"
	annotationDown           annotation = "Down"
	annotationStatementBegin annotation = "StatementBegin"
	annotationStatementEnd   annotation = "StatementEnd"
	annotationNoTransaction  annotation = "NO TRANSACTION"
	annotationEnvsubOn       annotation = "ENVSUB ON"
	annotationEnvsubOff      annotation = "ENVSUB OFF"
)

// supportedAnnotations is the set of annotations that extractAnnotation
// accepts; input is matched against its keys without regard to case.
var supportedAnnotations = map[annotation]struct{}{
	annotationUp:             {},
	annotationDown:           {},
	annotationStatementBegin: {},
	annotationStatementEnd:   {},
	annotationNoTransaction:  {},
	annotationEnvsubOn:       {},
	annotationEnvsubOff:      {},
}

// Sentinel errors returned (directly or wrapped) by extractAnnotation.
var (
	// errEmptyAnnotation: the line has the "+goose" marker but no command.
	errEmptyAnnotation = errors.New("empty annotation")
	// errInvalidAnnotation: the line is malformed or names an unknown command.
	errInvalidAnnotation = errors.New("invalid annotation")
)
// extractAnnotation returns the annotation carried by a comment line of the
// form "-- +goose [annotation]".
//
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION,
// ENVSUB ON, ENVSUB OFF. Matching ignores case and the canonical constant is
// returned.
//
// It returns errEmptyAnnotation when nothing follows the "+goose" marker, and
// an error wrapping errInvalidAnnotation when the line starts with a space or
// tab, carries more than one "+goose" marker, or names an unknown annotation.
func extractAnnotation(line string) (annotation, error) {
	// A line that begins with a space or a tab is rejected outright.
	if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
		return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
	}
	// Drop the comment dashes. Every "--" in the line is removed, not only
	// the leading pair.
	rest := strings.ReplaceAll(line, "--", "")
	// Drop the first "+goose" marker; a marker that is still present
	// afterwards means the line carries more than one.
	rest = strings.Replace(rest, "+goose", "", 1)
	if strings.Contains(rest, "+goose") {
		return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", rest, errInvalidAnnotation)
	}
	// What remains, minus surrounding white space, is the command itself.
	name := strings.TrimSpace(rest)
	if name == "" {
		return "", errEmptyAnnotation
	}
	for known := range supportedAnnotations {
		if strings.EqualFold(string(known), name) {
			return known, nil
		}
	}
	return "", fmt.Errorf("%q not supported: %w", name, errInvalidAnnotation)
}
// missingSemicolonError builds the error reported when SQL text is left in the
// parser's buffer without having been terminated. state and direction describe
// where the parser was, and s is the leftover text, which is quoted in the
// message.
func missingSemicolonError(state parserState, direction Direction, s string) error {
	const format = "failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?"
	return fmt.Errorf(format, state, direction, s)
}
// envWrapper exposes the process environment through the interpolate.Env
// interface so that it can be handed to interpolate.Interpolate.
type envWrapper struct{}

// Compile-time check that *envWrapper satisfies interpolate.Env.
var _ interpolate.Env = (*envWrapper)(nil)

// Get looks key up in the process environment. The boolean reports whether
// the variable is set, which distinguishes an empty value from a missing one.
func (*envWrapper) Get(key string) (string, bool) {
	value, found := os.LookupEnv(key)
	return value, found
}
// cleanupStatement normalizes a buffered statement before it is returned to
// the caller by removing all leading and trailing white space, including the
// newline appended after every buffered line.
func cleanupStatement(input string) string {
	trimmed := strings.TrimSpace(input)
	return trimmed
}
// endsWithSemicolon reports whether the SQL part of line ends with a
// semicolon.
//
// The line is split into white-space separated words. Words are examined from
// left to right until one begins with "--", which marks the start of a
// trailing comment; the result is whether the last word seen before that
// point ends with ';'. A "--" in the middle of a word does not start a
// comment. Empty lines and lines that consist only of a comment yield false.
func endsWithSemicolon(line string) bool {
	// strings.Fields splits on the same Unicode white space as a
	// bufio.ScanWords scanner, but needs neither a Scanner nor a scan
	// buffer, so nothing is borrowed from the buffer pool per line.
	last := ""
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			break
		}
		last = word
	}
	return strings.HasSuffix(last, ";")
}