mirror of https://github.com/pressly/goose.git
fix: add context-aware functions to goose validate (#662)
parent
1c856e67b6
commit
a52c60d6fb
|
@ -19,7 +19,6 @@ import (
|
|||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/cfg"
|
||||
"github.com/pressly/goose/v3/internal/migrationstats"
|
||||
"github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -320,8 +319,10 @@ func printValidate(filename string, verbose bool) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileWalker := migrationstatsos.NewFileWalker(filenames...)
|
||||
stats, err := migrationstats.GatherStats(fileWalker, false)
|
||||
stats, err := migrationstats.GatherStats(
|
||||
migrationstats.NewFileWalker(filenames...),
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -7,11 +7,14 @@ import (
|
|||
"go/parser"
|
||||
"go/token"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
registerGoFuncName = "AddMigration"
|
||||
registerGoFuncNameNoTx = "AddMigrationNoTx"
|
||||
registerGoFuncName = "AddMigration"
|
||||
registerGoFuncNameNoTx = "AddMigrationNoTx"
|
||||
registerGoFuncNameContext = "AddMigrationContext"
|
||||
registerGoFuncNameNoTxContext = "AddMigrationNoTxContext"
|
||||
)
|
||||
|
||||
type goMigration struct {
|
||||
|
@ -72,10 +75,10 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
|
|||
funcName := sel.Sel.Name
|
||||
b := false
|
||||
switch funcName {
|
||||
case registerGoFuncName:
|
||||
case registerGoFuncName, registerGoFuncNameContext:
|
||||
b = true
|
||||
gf.useTx = &b
|
||||
case registerGoFuncNameNoTx:
|
||||
case registerGoFuncNameNoTx, registerGoFuncNameNoTxContext:
|
||||
gf.useTx = &b
|
||||
default:
|
||||
continue
|
||||
|
@ -107,11 +110,15 @@ func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
|
|||
}
|
||||
// validation
|
||||
switch gf.name {
|
||||
case registerGoFuncName, registerGoFuncNameNoTx:
|
||||
case registerGoFuncName, registerGoFuncNameNoTx, registerGoFuncNameContext, registerGoFuncNameNoTxContext:
|
||||
default:
|
||||
return nil, fmt.Errorf("goose register function must be one of: %s or %s",
|
||||
registerGoFuncName,
|
||||
registerGoFuncNameNoTx,
|
||||
return nil, fmt.Errorf("goose register function must be one of: %s",
|
||||
strings.Join([]string{
|
||||
registerGoFuncName,
|
||||
registerGoFuncNameNoTx,
|
||||
registerGoFuncNameContext,
|
||||
registerGoFuncNameNoTxContext,
|
||||
}, ", "),
|
||||
)
|
||||
}
|
||||
if gf.useTx == nil {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package migrationstats
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -38,6 +40,47 @@ func TestParsingGoMigrations(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGoMigrationStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := "../../tests/gomigrations/success/testdata"
|
||||
all, err := os.ReadDir(base)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, len(all), 16)
|
||||
files := make([]string, 0, len(all))
|
||||
for _, f := range all {
|
||||
files = append(files, filepath.Join(base, f.Name()))
|
||||
}
|
||||
stats, err := GatherStats(NewFileWalker(files...), false)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, len(stats), 16)
|
||||
checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true)
|
||||
checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true)
|
||||
checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true)
|
||||
checkGoStats(t, stats[3], "004_empty.go", 4, 0, 0, true)
|
||||
checkGoStats(t, stats[4], "005_up_down_no_tx.go", 5, 1, 1, false)
|
||||
checkGoStats(t, stats[5], "006_up_only_no_tx.go", 6, 1, 0, false)
|
||||
checkGoStats(t, stats[6], "007_down_only_no_tx.go", 7, 0, 1, false)
|
||||
checkGoStats(t, stats[7], "008_empty_no_tx.go", 8, 0, 0, false)
|
||||
checkGoStats(t, stats[8], "009_up_down_ctx.go", 9, 1, 1, true)
|
||||
checkGoStats(t, stats[9], "010_up_only_ctx.go", 10, 1, 0, true)
|
||||
checkGoStats(t, stats[10], "011_down_only_ctx.go", 11, 0, 1, true)
|
||||
checkGoStats(t, stats[11], "012_empty_ctx.go", 12, 0, 0, true)
|
||||
checkGoStats(t, stats[12], "013_up_down_no_tx_ctx.go", 13, 1, 1, false)
|
||||
checkGoStats(t, stats[13], "014_up_only_no_tx_ctx.go", 14, 1, 0, false)
|
||||
checkGoStats(t, stats[14], "015_down_only_no_tx_ctx.go", 15, 0, 1, false)
|
||||
checkGoStats(t, stats[15], "016_empty_no_tx_ctx.go", 16, 0, 0, false)
|
||||
}
|
||||
|
||||
func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) {
|
||||
t.Helper()
|
||||
check.Equal(t, filepath.Base(stats.FileName), filename)
|
||||
check.Equal(t, stats.Version, version)
|
||||
check.Equal(t, stats.UpCount, upCount)
|
||||
check.Equal(t, stats.DownCount, downCount)
|
||||
check.Equal(t, stats.Tx, tx)
|
||||
}
|
||||
|
||||
func TestParsingGoMigrationsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := parseGoFile(strings.NewReader(emptyInit))
|
||||
|
@ -46,7 +89,7 @@ func TestParsingGoMigrationsError(t *testing.T) {
|
|||
|
||||
_, err = parseGoFile(strings.NewReader(wrongName))
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "AddMigration or AddMigrationNoTx")
|
||||
check.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext")
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
package migrationstatsos
|
||||
package migrationstats
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/migrationstats"
|
||||
)
|
||||
|
||||
// NewFileWalker returns a new FileWalker for the given filenames.
|
||||
//
|
||||
// Filenames without a .sql or .go extension are ignored.
|
||||
func NewFileWalker(filenames ...string) migrationstats.FileWalker {
|
||||
func NewFileWalker(filenames ...string) FileWalker {
|
||||
return &fileWalker{
|
||||
filenames: filenames,
|
||||
}
|
||||
|
@ -21,7 +19,7 @@ type fileWalker struct {
|
|||
filenames []string
|
||||
}
|
||||
|
||||
var _ migrationstats.FileWalker = (*fileWalker)(nil)
|
||||
var _ FileWalker = (*fileWalker)(nil)
|
||||
|
||||
func (f *fileWalker) Walk(fn func(filename string, r io.Reader) error) error {
|
||||
for _, filename := range f.filenames {
|
Loading…
Reference in New Issue