fix: add context-aware functions to goose validate (#662)

pull/627/head
Michael Fridman 2023-12-12 09:04:02 -05:00 committed by GitHub
parent 1c856e67b6
commit a52c60d6fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 17 deletions

View File

@ -19,7 +19,6 @@ import (
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/cfg"
"github.com/pressly/goose/v3/internal/migrationstats"
"github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos"
)
var (
@ -320,8 +319,10 @@ func printValidate(filename string, verbose bool) error {
if err != nil {
return err
}
fileWalker := migrationstatsos.NewFileWalker(filenames...)
stats, err := migrationstats.GatherStats(fileWalker, false)
stats, err := migrationstats.GatherStats(
migrationstats.NewFileWalker(filenames...),
false,
)
if err != nil {
return err
}

View File

@ -7,11 +7,14 @@ import (
"go/parser"
"go/token"
"io"
"strings"
)
const (
registerGoFuncName = "AddMigration"
registerGoFuncNameNoTx = "AddMigrationNoTx"
registerGoFuncName = "AddMigration"
registerGoFuncNameNoTx = "AddMigrationNoTx"
registerGoFuncNameContext = "AddMigrationContext"
registerGoFuncNameNoTxContext = "AddMigrationNoTxContext"
)
type goMigration struct {
@ -72,10 +75,10 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
funcName := sel.Sel.Name
b := false
switch funcName {
case registerGoFuncName:
case registerGoFuncName, registerGoFuncNameContext:
b = true
gf.useTx = &b
case registerGoFuncNameNoTx:
case registerGoFuncNameNoTx, registerGoFuncNameNoTxContext:
gf.useTx = &b
default:
continue
@ -107,11 +110,15 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
}
// validation
switch gf.name {
case registerGoFuncName, registerGoFuncNameNoTx:
case registerGoFuncName, registerGoFuncNameNoTx, registerGoFuncNameContext, registerGoFuncNameNoTxContext:
default:
return nil, fmt.Errorf("goose register function must be one of: %s or %s",
registerGoFuncName,
registerGoFuncNameNoTx,
return nil, fmt.Errorf("goose register function must be one of: %s",
strings.Join([]string{
registerGoFuncName,
registerGoFuncNameNoTx,
registerGoFuncNameContext,
registerGoFuncNameNoTxContext,
}, ", "),
)
}
if gf.useTx == nil {

View File

@ -1,6 +1,8 @@
package migrationstats
import (
"os"
"path/filepath"
"strings"
"testing"
@ -38,6 +40,47 @@ func TestParsingGoMigrations(t *testing.T) {
}
}
func TestGoMigrationStats(t *testing.T) {
t.Parallel()
base := "../../tests/gomigrations/success/testdata"
all, err := os.ReadDir(base)
check.NoError(t, err)
check.Equal(t, len(all), 16)
files := make([]string, 0, len(all))
for _, f := range all {
files = append(files, filepath.Join(base, f.Name()))
}
stats, err := GatherStats(NewFileWalker(files...), false)
check.NoError(t, err)
check.Equal(t, len(stats), 16)
checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true)
checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true)
checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true)
checkGoStats(t, stats[3], "004_empty.go", 4, 0, 0, true)
checkGoStats(t, stats[4], "005_up_down_no_tx.go", 5, 1, 1, false)
checkGoStats(t, stats[5], "006_up_only_no_tx.go", 6, 1, 0, false)
checkGoStats(t, stats[6], "007_down_only_no_tx.go", 7, 0, 1, false)
checkGoStats(t, stats[7], "008_empty_no_tx.go", 8, 0, 0, false)
checkGoStats(t, stats[8], "009_up_down_ctx.go", 9, 1, 1, true)
checkGoStats(t, stats[9], "010_up_only_ctx.go", 10, 1, 0, true)
checkGoStats(t, stats[10], "011_down_only_ctx.go", 11, 0, 1, true)
checkGoStats(t, stats[11], "012_empty_ctx.go", 12, 0, 0, true)
checkGoStats(t, stats[12], "013_up_down_no_tx_ctx.go", 13, 1, 1, false)
checkGoStats(t, stats[13], "014_up_only_no_tx_ctx.go", 14, 1, 0, false)
checkGoStats(t, stats[14], "015_down_only_no_tx_ctx.go", 15, 0, 1, false)
checkGoStats(t, stats[15], "016_empty_no_tx_ctx.go", 16, 0, 0, false)
}
func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) {
t.Helper()
check.Equal(t, filepath.Base(stats.FileName), filename)
check.Equal(t, stats.Version, version)
check.Equal(t, stats.UpCount, upCount)
check.Equal(t, stats.DownCount, downCount)
check.Equal(t, stats.Tx, tx)
}
func TestParsingGoMigrationsError(t *testing.T) {
t.Parallel()
_, err := parseGoFile(strings.NewReader(emptyInit))
@ -46,7 +89,7 @@ func TestParsingGoMigrationsError(t *testing.T) {
_, err = parseGoFile(strings.NewReader(wrongName))
check.HasError(t, err)
check.Contains(t, err.Error(), "AddMigration or AddMigrationNoTx")
check.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext")
}
var (

View File

@ -1,17 +1,15 @@
package migrationstatsos
package migrationstats
import (
"io"
"os"
"path/filepath"
"github.com/pressly/goose/v3/internal/migrationstats"
)
// NewFileWalker returns a new FileWalker for the given filenames.
//
// Filenames without a .sql or .go extension are ignored.
func NewFileWalker(filenames ...string) migrationstats.FileWalker {
func NewFileWalker(filenames ...string) FileWalker {
return &fileWalker{
filenames: filenames,
}
@ -21,7 +19,7 @@ type fileWalker struct {
filenames []string
}
var _ migrationstats.FileWalker = (*fileWalker)(nil)
var _ FileWalker = (*fileWalker)(nil)
func (f *fileWalker) Walk(fn func(filename string, r io.Reader) error) error {
for _, filename := range f.filenames {