diff --git a/internal/sqlparser/parser.go b/internal/sqlparser/parser.go index bba7990..f08ae3e 100644 --- a/internal/sqlparser/parser.go +++ b/internal/sqlparser/parser.go @@ -121,13 +121,19 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st if stateMachine.get() == start && strings.TrimSpace(line) == "" { continue } - // TODO(mf): validate annotations to avoid common user errors: - // https://github.com/pressly/goose/issues/163#issuecomment-501736725 - if strings.HasPrefix(line, "--") { - cmd := strings.TrimSpace(strings.TrimPrefix(line, "--")) + + // Check for annotations. + // All annotations must be in format: "-- +goose [annotation]" + if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") { + var cmd annotation + + cmd, err = extractAnnotation(line) + if err != nil { + return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err) + } switch cmd { - case "+goose Up": + case annotationUp: switch stateMachine.get() { case start: stateMachine.set(gooseUp) @@ -136,7 +142,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st } continue - case "+goose Down": + case annotationDown: switch stateMachine.get() { case gooseUp, gooseStatementEndUp: // If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a @@ -151,7 +157,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st } continue - case "+goose StatementBegin": + case annotationStatementBegin: switch stateMachine.get() { case gooseUp, gooseStatementEndUp: stateMachine.set(gooseStatementBeginUp) @@ -162,7 +168,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st } continue - case "+goose StatementEnd": + case annotationStatementEnd: switch stateMachine.get() { case gooseStatementBeginUp: stateMachine.set(gooseStatementEndUp) @@ -172,17 +178,20 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations") } - case "+goose NO TRANSACTION": + case annotationNoTransaction: useTx = false continue - case "+goose ENVSUB ON": + case annotationEnvsubOn: useEnvsub = true continue - case "+goose ENVSUB OFF": + case annotationEnvsubOff: useEnvsub = false continue + + default: + return nil, false, fmt.Errorf("unknown annotation: %q", cmd) } } // Once we've started parsing a statement the buffer is no longer empty, @@ -277,6 +286,70 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st return stmts, useTx, nil } +type annotation string + +const ( + annotationUp annotation = "Up" + annotationDown annotation = "Down" + annotationStatementBegin annotation = "StatementBegin" + annotationStatementEnd annotation = "StatementEnd" + annotationNoTransaction annotation = "NO TRANSACTION" + annotationEnvsubOn annotation = "ENVSUB ON" + annotationEnvsubOff annotation = "ENVSUB OFF" +) + +var supportedAnnotations = map[annotation]struct{}{ + annotationUp: {}, + annotationDown: {}, + annotationStatementBegin: {}, + annotationStatementEnd: {}, + annotationNoTransaction: {}, + annotationEnvsubOn: {}, + annotationEnvsubOff: {}, +} + +var ( + errEmptyAnnotation = errors.New("empty annotation") + errInvalidAnnotation = errors.New("invalid annotation") +) + +// extractAnnotation extracts the annotation from the line. +// All annotations must be in format: "-- +goose [annotation]" +// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF +func extractAnnotation(line string) (annotation, error) { + // If line contains leading whitespace - return error. + if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") { + return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation) + } + + // Extract the annotation from the line, by removing the leading "--" + cmd := strings.ReplaceAll(line, "--", "") + + // Extract the annotation from the line, by removing the leading "+goose" + cmd = strings.Replace(cmd, "+goose", "", 1) + + if strings.Contains(cmd, "+goose") { + return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", cmd, errInvalidAnnotation) + } + + // Remove leading and trailing whitespace from the annotation command. + cmd = strings.TrimSpace(cmd) + + if cmd == "" { + return "", errEmptyAnnotation + } + + a := annotation(cmd) + + for s := range supportedAnnotations { + if strings.EqualFold(string(s), string(a)) { + return s, nil + } + } + + return "", fmt.Errorf("%q not supported: %w", cmd, errInvalidAnnotation) +} + func missingSemicolonError(state parserState, direction Direction, s string) error { return fmt.Errorf("failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?", state, diff --git a/internal/sqlparser/parser_test.go b/internal/sqlparser/parser_test.go index efe47d7..155717f 100644 --- a/internal/sqlparser/parser_test.go +++ b/internal/sqlparser/parser_test.go @@ -507,3 +507,82 @@ CREATE TABLE post ( check.HasError(t, err) check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:") } + +func Test_extractAnnotation(t *testing.T) { + tests := []struct { + name string + input string + want annotation + wantErr func(t *testing.T, err error) + }{ + { + name: "Up", + input: "-- +goose Up", + want: annotationUp, + wantErr: check.NoError, + }, + { + name: "Down", + input: "-- +goose Down", + want: annotationDown, + wantErr: check.NoError, + }, + { + name: "StmtBegin", + input: "-- +goose StatementBegin", + want: annotationStatementBegin, + wantErr: check.NoError, + }, + { + name: "NoTransact", + input: "-- +goose NO TRANSACTION", + want: annotationNoTransaction, + wantErr: check.NoError, + }, + { + name: "Unsupported", + input: "-- +goose unsupported", + want: "", + wantErr: check.HasError, + }, + { + name: "Empty", + input: "-- +goose", + want: "", + wantErr: check.HasError, + }, + { + name: "statement with spaces and Uppercase", + input: "-- +goose UP ", + want: annotationUp, + wantErr: check.NoError, + }, + { + name: "statement with leading whitespace - error", + input: " -- +goose UP ", + want: "", + wantErr: check.HasError, + }, + { + name: "statement with leading \t - error", + input: "\t-- +goose UP ", + want: "", + wantErr: check.HasError, + }, + { + name: "multiple +goose annotations - error", + input: "-- +goose +goose Up", + want: "", + wantErr: check.HasError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractAnnotation(tt.input) + tt.wantErr(t, err) + + check.Equal(t, got, tt.want) + }) + } +} diff --git a/internal/sqlparser/testdata/valid-up/test01/input.sql b/internal/sqlparser/testdata/valid-up/test01/input.sql index 1fadbc7..354fc52 100644 --- a/internal/sqlparser/testdata/valid-up/test01/input.sql +++ b/internal/sqlparser/testdata/valid-up/test01/input.sql @@ -1,4 +1,4 @@ --- +goose Up +-- +goose UP CREATE TABLE emp ( empname text, salary integer, @@ -6,7 +6,7 @@ CREATE TABLE emp ( last_user text ); --- +goose StatementBegin +-- +goose statementBegin CREATE FUNCTION emp_stamp() RETURNS trigger AS $emp_stamp$ BEGIN -- Check that empname and salary are given