feat: `goose validate` command (#449)

pull/470/head
Michael Fridman 2023-02-25 14:33:10 -05:00 committed by GitHub
parent 60610d3ae3
commit 8c25e3bd17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 554 additions and 6 deletions

View File

@ -21,7 +21,7 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: v1.50.1
version: latest
# Optional: working directory, useful for monorepos
# working-directory: somedir

View File

@ -20,7 +20,7 @@ lint: tools
.PHONY: tools
tools:
@go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.50.1
@go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test-packages:
go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
@ -49,4 +49,4 @@ docker-start-postgres:
-e POSTGRES_DB=${GOOSE_POSTGRES_DBNAME} \
-p ${GOOSE_POSTGRES_PORT}:5432 \
-l goose_test \
postgres:14-alpine
postgres:14-alpine -c log_statement=all

View File

@ -7,12 +7,18 @@ import (
"io/fs"
"log"
"os"
"path/filepath"
"runtime/debug"
"sort"
"strconv"
"strings"
"text/tabwriter"
"text/template"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/cfg"
"github.com/pressly/goose/v3/internal/migrationstats"
"github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos"
)
var (
@ -95,6 +101,11 @@ func main() {
fmt.Printf("%s=%q\n", env.Name, env.Value)
}
return
case "validate":
if err := printValidate(*dir, *verbose); err != nil {
log.Fatalf("goose validate: %v", err)
}
return
}
args = mergeArgs(args)
@ -278,3 +289,59 @@ func gooseInit(dir string) error {
}
return goose.CreateWithTemplate(nil, dir, sqlMigrationTemplate, "initial", "sql")
}
func gatherFilenames(filename string) ([]string, error) {
stat, err := os.Stat(filename)
if err != nil {
return nil, err
}
var filenames []string
if stat.IsDir() {
for _, pattern := range []string{"*.sql", "*.go"} {
file, err := filepath.Glob(filepath.Join(filename, pattern))
if err != nil {
return nil, err
}
filenames = append(filenames, file...)
}
} else {
filenames = append(filenames, filename)
}
sort.Strings(filenames)
return filenames, nil
}
func printValidate(filename string, verbose bool) error {
filenames, err := gatherFilenames(filename)
if err != nil {
return err
}
fileWalker := migrationstatsos.NewFileWalker(filenames...)
stats, err := migrationstats.GatherStats(fileWalker, false)
if err != nil {
return err
}
// TODO(mf): we should introduce a --debug flag, which allows printing
// more internal debug information and leave verbose for additional information.
if !verbose {
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', tabwriter.TabIndent)
fmtPattern := "%v\t%v\t%v\t%v\t%v\t\n"
fmt.Fprintf(w, fmtPattern, "Type", "Txn", "Up", "Down", "Name")
fmt.Fprintf(w, fmtPattern, "────", "───", "──", "────", "────")
for _, m := range stats {
txnStr := "✔"
if !m.Tx {
txnStr = "✘"
}
fmt.Fprintf(w, fmtPattern,
strings.TrimPrefix(filepath.Ext(m.FileName), "."),
txnStr,
m.UpCount,
m.DownCount,
filepath.Base(m.FileName),
)
}
return w.Flush()
}

View File

@ -0,0 +1,129 @@
package migrationstats
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
)
const (
registerGoFuncName = "AddMigration"
registerGoFuncNameNoTx = "AddMigrationNoTx"
)
type goMigration struct {
name string
useTx *bool
upFuncName, downFuncName string
}
func parseGoFile(r io.Reader) (*goMigration, error) {
astFile, err := parser.ParseFile(
token.NewFileSet(),
"", // filename
r,
// We don't need to resolve imports, so we can skip it.
// This speeds up the parsing process.
// See https://github.com/golang/go/issues/46485
parser.SkipObjectResolution,
)
if err != nil {
return nil, err
}
for _, decl := range astFile.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok || fn == nil || fn.Name == nil {
continue
}
if fn.Name.Name == "init" {
return parseInitFunc(fn)
}
}
return nil, errors.New("no init function")
}
func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
if fd == nil {
return nil, fmt.Errorf("function declaration must not be nil")
}
if fd.Body == nil {
return nil, fmt.Errorf("no function body")
}
if len(fd.Body.List) == 0 {
return nil, fmt.Errorf("no registered goose functions")
}
gf := new(goMigration)
for _, statement := range fd.Body.List {
expr, ok := statement.(*ast.ExprStmt)
if !ok {
continue
}
call, ok := expr.X.(*ast.CallExpr)
if !ok {
continue
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok || sel == nil {
continue
}
funcName := sel.Sel.Name
b := false
switch funcName {
case registerGoFuncName:
b = true
gf.useTx = &b
case registerGoFuncNameNoTx:
gf.useTx = &b
default:
continue
}
if gf.name != "" {
return nil, fmt.Errorf("found duplicate registered functions:\nprevious: %v\ncurrent: %v", gf.name, funcName)
}
gf.name = funcName
if len(call.Args) != 2 {
return nil, fmt.Errorf("registered goose functions have 2 arguments: got %d", len(call.Args))
}
getNameFromExpr := func(expr ast.Expr) (string, error) {
arg, ok := expr.(*ast.Ident)
if !ok {
return "", fmt.Errorf("failed to assert argument identifer: got %T", arg)
}
return arg.Name, nil
}
var err error
gf.upFuncName, err = getNameFromExpr(call.Args[0])
if err != nil {
return nil, err
}
gf.downFuncName, err = getNameFromExpr(call.Args[1])
if err != nil {
return nil, err
}
}
// validation
switch gf.name {
case registerGoFuncName, registerGoFuncNameNoTx:
default:
return nil, fmt.Errorf("goose register function must be one of: %s or %s",
registerGoFuncName,
registerGoFuncNameNoTx,
)
}
if gf.useTx == nil {
return nil, errors.New("validation error: failed to identify transaction: got nil bool")
}
// The up and down functions can either be named Go functions or "nil", an
// empty string means there is a flaw in our parsing logic of the Go source code.
if gf.upFuncName == "" {
return nil, fmt.Errorf("validation error: up function is empty string")
}
if gf.downFuncName == "" {
return nil, fmt.Errorf("validation error: down function is empty string")
}
return gf, nil
}

View File

@ -0,0 +1,47 @@
package migrationstats
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"github.com/pressly/goose/v3/internal/sqlparser"
)
type sqlMigration struct {
useTx bool
upCount, downCount int
}
func parseSQLFile(r io.Reader, debug bool) (*sqlMigration, error) {
by, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
upStatements, txUp, err := sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionUp,
debug,
)
if err != nil {
return nil, err
}
downStatements, txDown, err := sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionDown,
debug,
)
if err != nil {
return nil, err
}
// This is a sanity check to ensure that the parser is behaving as expected.
if txUp != txDown {
return nil, fmt.Errorf("up and down statements must have the same transaction mode")
}
return &sqlMigration{
useTx: txUp,
upCount: len(upStatements),
downCount: len(downStatements),
}, nil
}

View File

@ -0,0 +1,78 @@
package migrationstats
import (
"fmt"
"io"
"path/filepath"
"github.com/pressly/goose/v3"
)
// FileWalker walks all files for GatherStats.
type FileWalker interface {
// Walk invokes fn for each file.
Walk(fn func(filename string, r io.Reader) error) error
}
// Stats contains the stats for a migration file.
type Stats struct {
// FileName is the name of the file.
FileName string
// Version is the version of the migration.
Version int64
// Tx is true if the .sql migration file has a +goose NO TRANSACTION annotation
// or the .go migration file calls AddMigrationNoTx.
Tx bool
// UpCount is the number of statements in the Up migration.
UpCount int
// DownCount is the number of statements in the Down migration.
DownCount int
}
// GatherStats returns the migration file stats.
func GatherStats(fw FileWalker, debug bool) ([]*Stats, error) {
var stats []*Stats
err := fw.Walk(func(filename string, r io.Reader) error {
version, err := goose.NumericComponent(filename)
if err != nil {
return fmt.Errorf("failed to get version from file %q: %w", filename, err)
}
var up, down int
var tx bool
switch filepath.Ext(filename) {
case ".sql":
m, err := parseSQLFile(r, debug)
if err != nil {
return fmt.Errorf("failed to parse file %q: %w", filename, err)
}
up, down = m.upCount, m.downCount
tx = m.useTx
case ".go":
m, err := parseGoFile(r)
if err != nil {
return fmt.Errorf("failed to parse file %q: %w", filename, err)
}
up, down = nilAsNumber(m.upFuncName), nilAsNumber(m.downFuncName)
tx = *m.useTx
}
stats = append(stats, &Stats{
FileName: filename,
Version: version,
Tx: tx,
UpCount: up,
DownCount: down,
})
return nil
})
if err != nil {
return nil, err
}
return stats, nil
}
func nilAsNumber(s string) int {
if s != "nil" {
return 1
}
return 0
}

View File

@ -0,0 +1,181 @@
package migrationstats
import (
"strings"
"testing"
"github.com/pressly/goose/v3/internal/check"
)
func TestParsingGoMigrations(t *testing.T) {
tests := []struct {
name string
input string
wantUpName, wantDownName string
wantTx bool
}{
// AddMigration
{"upAndDown", upAndDown, "up001", "down001", true},
{"downOnly", downOnly, "nil", "down002", true},
{"upOnly", upOnly, "up003", "nil", true},
{"upAndDownNil", upAndDownNil, "nil", "nil", true},
// AddMigrationNoTx
{"upAndDownNoTx", upAndDownNoTx, "up001", "down001", false},
{"downOnlyNoTx", downOnlyNoTx, "nil", "down002", false},
{"upOnlyNoTx", upOnlyNoTx, "up003", "nil", false},
{"upAndDownNilNoTx", upAndDownNilNoTx, "nil", "nil", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
g, err := parseGoFile(strings.NewReader(tc.input))
check.NoError(t, err)
check.Equal(t, g.useTx != nil, true)
check.Bool(t, *g.useTx, tc.wantTx)
check.Equal(t, g.downFuncName, tc.wantDownName)
check.Equal(t, g.upFuncName, tc.wantUpName)
})
}
}
func TestParsingGoMigrationsError(t *testing.T) {
_, err := parseGoFile(strings.NewReader(emptyInit))
check.HasError(t, err)
check.Contains(t, err.Error(), "no registered goose functions")
_, err = parseGoFile(strings.NewReader(wrongName))
check.HasError(t, err)
check.Contains(t, err.Error(), "AddMigration or AddMigrationNoTx")
}
var (
upAndDown = `package foo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up001, down001)
}
func up001(tx *sql.Tx) error { return nil }
func down001(tx *sql.Tx) error { return nil }`
downOnly = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(nil, down002)
}
func down002(tx *sql.Tx) error { return nil }`
upOnly = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up003, nil)
}
func up003(tx *sql.Tx) error { return nil }`
upAndDownNil = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(nil, nil)
}`
)
var (
upAndDownNoTx = `package foo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up001, down001)
}
func up001(db *sql.DB) error { return nil }
func down001(db *sql.DB) error { return nil }`
downOnlyNoTx = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(nil, down002)
}
func down002(db *sql.DB) error { return nil }`
upOnlyNoTx = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up003, nil)
}
func up003(db *sql.DB) error { return nil }`
upAndDownNilNoTx = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(nil, nil)
}`
)
var (
emptyInit = `package testgo
func init() {}`
wrongName = `package testgo
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationWrongName(nil, nil)
}`
)

View File

@ -0,0 +1,46 @@
package migrationstatsos
import (
"io"
"os"
"path/filepath"
"github.com/pressly/goose/v3/internal/migrationstats"
)
// NewFileWalker returns a new FileWalker for the given filenames.
//
// Filenames without a .sql or .go extension are ignored.
func NewFileWalker(filenames ...string) migrationstats.FileWalker {
return &fileWalker{
filenames: filenames,
}
}
type fileWalker struct {
filenames []string
}
var _ migrationstats.FileWalker = (*fileWalker)(nil)
func (f *fileWalker) Walk(fn func(filename string, r io.Reader) error) error {
for _, filename := range f.filenames {
ext := filepath.Ext(filename)
if ext != ".sql" && ext != ".go" {
continue
}
if err := walk(filename, fn); err != nil {
return err
}
}
return nil
}
func walk(filename string, fn func(filename string, r io.Reader) error) error {
file, err := os.Open(filename)
if err != nil {
return err
}
defer file.Close()
return fn(filename, file)
}

View File

@ -89,7 +89,7 @@ var bufferPool = sync.Pool{
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func ParseSQLMigration(r io.Reader, direction Direction, verbose bool) (stmts []string, useTx bool, err error) {
func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) {
scanBufPtr := bufferPool.Get().(*[]byte)
scanBuf := *scanBufPtr
defer bufferPool.Put(scanBufPtr)
@ -97,13 +97,13 @@ func ParseSQLMigration(r io.Reader, direction Direction, verbose bool) (stmts []
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
stateMachine := newStateMachine(start, verbose)
stateMachine := newStateMachine(start, debug)
useTx = true
var buf bytes.Buffer
for scanner.Scan() {
line := scanner.Text()
if verbose {
if debug {
log.Println(line)
}
if stateMachine.get() == start && strings.TrimSpace(line) == "" {