feat: Make goose annotations case-insensitive (#704)

pull/712/head
Oleg Balunenko 2024-03-04 17:35:26 +03:00 committed by GitHub
parent 76946ccca3
commit 48100ea926
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 165 additions and 13 deletions

View File

@ -121,13 +121,19 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
if stateMachine.get() == start && strings.TrimSpace(line) == "" {
continue
}
// TODO(mf): validate annotations to avoid common user errors:
// https://github.com/pressly/goose/issues/163#issuecomment-501736725
if strings.HasPrefix(line, "--") {
cmd := strings.TrimSpace(strings.TrimPrefix(line, "--"))
// Check for annotations.
// All annotations must be in format: "-- +goose [annotation]"
if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
var cmd annotation
cmd, err = extractAnnotation(line)
if err != nil {
return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
}
switch cmd {
case "+goose Up":
case annotationUp:
switch stateMachine.get() {
case start:
stateMachine.set(gooseUp)
@ -136,7 +142,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue
case "+goose Down":
case annotationDown:
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
// If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
@ -151,7 +157,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue
case "+goose StatementBegin":
case annotationStatementBegin:
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
stateMachine.set(gooseStatementBeginUp)
@ -162,7 +168,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue
case "+goose StatementEnd":
case annotationStatementEnd:
switch stateMachine.get() {
case gooseStatementBeginUp:
stateMachine.set(gooseStatementEndUp)
@ -172,17 +178,20 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
}
case "+goose NO TRANSACTION":
case annotationNoTransaction:
useTx = false
continue
case "+goose ENVSUB ON":
case annotationEnvsubOn:
useEnvsub = true
continue
case "+goose ENVSUB OFF":
case annotationEnvsubOff:
useEnvsub = false
continue
default:
return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
}
}
// Once we've started parsing a statement the buffer is no longer empty,
@ -277,6 +286,70 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
return stmts, useTx, nil
}
type annotation string
const (
annotationUp annotation = "Up"
annotationDown annotation = "Down"
annotationStatementBegin annotation = "StatementBegin"
annotationStatementEnd annotation = "StatementEnd"
annotationNoTransaction annotation = "NO TRANSACTION"
annotationEnvsubOn annotation = "ENVSUB ON"
annotationEnvsubOff annotation = "ENVSUB OFF"
)
var supportedAnnotations = map[annotation]struct{}{
annotationUp: {},
annotationDown: {},
annotationStatementBegin: {},
annotationStatementEnd: {},
annotationNoTransaction: {},
annotationEnvsubOn: {},
annotationEnvsubOff: {},
}
var (
errEmptyAnnotation = errors.New("empty annotation")
errInvalidAnnotation = errors.New("invalid annotation")
)
// extractAnnotation extracts the annotation from the line.
// All annotations must be in format: "-- +goose [annotation]"
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF
func extractAnnotation(line string) (annotation, error) {
// If line contains leading whitespace - return error.
if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") {
return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
}
// Extract the annotation from the line, by removing the leading "--"
cmd := strings.ReplaceAll(line, "--", "")
// Extract the annotation from the line, by removing the leading "+goose"
cmd = strings.Replace(cmd, "+goose", "", 1)
if strings.Contains(cmd, "+goose") {
return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", cmd, errInvalidAnnotation)
}
// Remove leading and trailing whitespace from the annotation command.
cmd = strings.TrimSpace(cmd)
if cmd == "" {
return "", errEmptyAnnotation
}
a := annotation(cmd)
for s := range supportedAnnotations {
if strings.EqualFold(string(s), string(a)) {
return s, nil
}
}
return "", fmt.Errorf("%q not supported: %w", cmd, errInvalidAnnotation)
}
func missingSemicolonError(state parserState, direction Direction, s string) error {
return fmt.Errorf("failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?",
state,

View File

@ -507,3 +507,82 @@ CREATE TABLE post (
check.HasError(t, err)
check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:")
}
func Test_extractAnnotation(t *testing.T) {
tests := []struct {
name string
input string
want annotation
wantErr func(t *testing.T, err error)
}{
{
name: "Up",
input: "-- +goose Up",
want: annotationUp,
wantErr: check.NoError,
},
{
name: "Down",
input: "-- +goose Down",
want: annotationDown,
wantErr: check.NoError,
},
{
name: "StmtBegin",
input: "-- +goose StatementBegin",
want: annotationStatementBegin,
wantErr: check.NoError,
},
{
name: "NoTransact",
input: "-- +goose NO TRANSACTION",
want: annotationNoTransaction,
wantErr: check.NoError,
},
{
name: "Unsupported",
input: "-- +goose unsupported",
want: "",
wantErr: check.HasError,
},
{
name: "Empty",
input: "-- +goose",
want: "",
wantErr: check.HasError,
},
{
name: "statement with spaces and Uppercase",
input: "-- +goose UP ",
want: annotationUp,
wantErr: check.NoError,
},
{
name: "statement with leading whitespace - error",
input: " -- +goose UP ",
want: "",
wantErr: check.HasError,
},
{
name: "statement with leading \t - error",
input: "\t-- +goose UP ",
want: "",
wantErr: check.HasError,
},
{
name: "multiple +goose annotations - error",
input: "-- +goose +goose Up",
want: "",
wantErr: check.HasError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractAnnotation(tt.input)
tt.wantErr(t, err)
check.Equal(t, got, tt.want)
})
}
}

View File

@ -1,4 +1,4 @@
-- +goose Up
-- +goose UP
CREATE TABLE emp (
empname text,
salary integer,
@ -6,7 +6,7 @@ CREATE TABLE emp (
last_user text
);
-- +goose StatementBegin
-- +goose statementBegin
CREATE FUNCTION emp_stamp() RETURNS trigger AS $emp_stamp$
BEGIN
-- Check that empname and salary are given