mirror of https://github.com/pressly/goose.git
feat(experimental): shuffle packages & add explicit provider Go func registration (#616)
parent
68853f91ea
commit
58f8534610
1
go.mod
1
go.mod
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/vertica/vertica-sql-go v1.3.3
|
||||
github.com/ziutek/mymysql v1.5.4
|
||||
go.uber.org/multierr v1.11.0
|
||||
golang.org/x/sync v0.4.0
|
||||
modernc.org/sqlite v1.26.0
|
||||
)
|
||||
|
||||
|
|
3
go.sum
3
go.sum
|
@ -184,7 +184,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
|
||||
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
// Package migrate defines a Migration struct and implements the migration logic for executing Go
|
||||
// and SQL migrations.
|
||||
//
|
||||
// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported.
|
||||
// - For SQL migrations, all three are supported.
|
||||
//
|
||||
// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first
|
||||
// time it is executed.
|
||||
package migrate
|
|
@ -1,166 +0,0 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlextended"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
// Fullpath is the full path to the migration file.
|
||||
//
|
||||
// Example: /path/to/migrations/123_create_users_table.go
|
||||
Fullpath string
|
||||
// Version is the version of the migration.
|
||||
Version int64
|
||||
// Type is the type of migration.
|
||||
Type MigrationType
|
||||
// A migration is either a Go migration or a SQL migration, but never both.
|
||||
//
|
||||
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
|
||||
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
|
||||
// majority of the time migrations are incremental, so it is likely that the user will only want
|
||||
// to run the last few migrations, and there is no need to parse ALL prior migrations.
|
||||
//
|
||||
// Exactly one of these fields will be set:
|
||||
Go *Go
|
||||
// -- or --
|
||||
SQLParsed bool
|
||||
SQL *SQL
|
||||
}
|
||||
|
||||
type MigrationType int
|
||||
|
||||
const (
|
||||
TypeGo MigrationType = iota + 1
|
||||
TypeSQL
|
||||
)
|
||||
|
||||
func (t MigrationType) String() string {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
return "go"
|
||||
case TypeSQL:
|
||||
return "sql"
|
||||
default:
|
||||
// This should never happen.
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) UseTx() bool {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
return m.Go.UseTx
|
||||
case TypeSQL:
|
||||
return m.SQL.UseTx
|
||||
default:
|
||||
// This should never happen.
|
||||
panic("unknown migration type: use tx")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) IsEmpty(direction bool) bool {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
return m.Go.IsEmpty(direction)
|
||||
case TypeSQL:
|
||||
return m.SQL.IsEmpty(direction)
|
||||
default:
|
||||
// This should never happen.
|
||||
panic("unknown migration type: is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) GetSQLStatements(direction bool) ([]string, error) {
|
||||
if m.Type != TypeSQL {
|
||||
return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type)
|
||||
}
|
||||
if m.SQL == nil {
|
||||
return nil, errors.New("sql migration has not been initialized")
|
||||
}
|
||||
if !m.SQLParsed {
|
||||
return nil, errors.New("sql migration has not been parsed")
|
||||
}
|
||||
if direction {
|
||||
return m.SQL.UpStatements, nil
|
||||
}
|
||||
return m.SQL.DownStatements, nil
|
||||
}
|
||||
|
||||
type Go struct {
|
||||
// We used an explicit bool instead of relying on a pointer because registered funcs may be nil.
|
||||
// These are still valid Go and versioned migrations, but they are just empty.
|
||||
//
|
||||
// For example: goose.AddMigration(nil, nil)
|
||||
UseTx bool
|
||||
|
||||
// Only one of these func pairs will be set:
|
||||
UpFn, DownFn func(context.Context, *sql.Tx) error
|
||||
// -- or --
|
||||
UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
func (g *Go) IsEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return g.UpFn == nil && g.UpFnNoTx == nil
|
||||
}
|
||||
return g.DownFn == nil && g.DownFnNoTx == nil
|
||||
}
|
||||
|
||||
func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
var fn func(context.Context, *sql.Tx) error
|
||||
if direction {
|
||||
fn = g.UpFn
|
||||
} else {
|
||||
fn = g.DownFn
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
var fn func(context.Context, *sql.DB) error
|
||||
if direction {
|
||||
fn = g.UpFnNoTx
|
||||
} else {
|
||||
fn = g.DownFnNoTx
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SQL struct {
|
||||
UseTx bool
|
||||
UpStatements []string
|
||||
DownStatements []string
|
||||
}
|
||||
|
||||
func (s *SQL) IsEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return len(s.UpStatements) == 0
|
||||
}
|
||||
return len(s.DownStatements) == 0
|
||||
}
|
||||
|
||||
func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
|
||||
var statements []string
|
||||
if direction {
|
||||
statements = s.UpStatements
|
||||
} else {
|
||||
statements = s.DownStatements
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/fs"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
)
|
||||
|
||||
// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
|
||||
// will not be parsed again.
|
||||
//
|
||||
// Important: This function will mutate SQL migrations.
|
||||
func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error {
|
||||
for _, m := range migrations {
|
||||
if m.Type == TypeSQL && !m.SQLParsed {
|
||||
parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.SQLParsed = true
|
||||
m.SQL = parsedSQLMigration
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse is used to determine which direction to parse the SQL migration.
|
||||
type parse int
|
||||
|
||||
const (
|
||||
// parseAll parses all SQL statements in BOTH directions.
|
||||
parseAll parse = iota + 1
|
||||
// parseUp parses all SQL statements in the UP direction.
|
||||
parseUp
|
||||
// parseDown parses all SQL statements in the DOWN direction.
|
||||
parseDown
|
||||
)
|
||||
|
||||
func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) {
|
||||
r, err := fsys.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
by, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := new(SQL)
|
||||
if p == parseAll || p == parseUp {
|
||||
s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
|
||||
bytes.NewReader(by),
|
||||
sqlparser.DirectionUp,
|
||||
debug,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p == parseAll || p == parseDown {
|
||||
s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
|
||||
bytes.NewReader(by),
|
||||
sqlparser.DirectionDown,
|
||||
debug,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
|
@ -9,15 +9,56 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/migrate"
|
||||
)
|
||||
|
||||
// Source represents a single migration source.
|
||||
//
|
||||
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
|
||||
// the migration has a corresponding file on disk. It will be empty if the migration was registered
|
||||
// manually.
|
||||
type Source struct {
|
||||
// Type is the type of migration.
|
||||
Type MigrationType
|
||||
// Full path to the migration file.
|
||||
//
|
||||
// Example: /path/to/migrations/001_create_users_table.sql
|
||||
Fullpath string
|
||||
// Version is the version of the migration.
|
||||
Version int64
|
||||
}
|
||||
|
||||
func newSource(t MigrationType, fullpath string, version int64) Source {
|
||||
return Source{
|
||||
Type: t,
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// fileSources represents a collection of migration files on the filesystem.
|
||||
type fileSources struct {
|
||||
sqlSources []Source
|
||||
goSources []Source
|
||||
}
|
||||
|
||||
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
for _, source := range s.goSources {
|
||||
if source.Version == version {
|
||||
return &source
|
||||
}
|
||||
}
|
||||
case TypeSQL:
|
||||
for _, source := range s.sqlSources {
|
||||
if source.Version == version {
|
||||
return &source
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
|
||||
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
|
||||
// in which case an empty fileSources is returned.
|
||||
|
@ -69,15 +110,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
|
|||
}
|
||||
switch filepath.Ext(base) {
|
||||
case ".sql":
|
||||
sources.sqlSources = append(sources.sqlSources, Source{
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
})
|
||||
sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version))
|
||||
case ".go":
|
||||
sources.goSources = append(sources.goSources, Source{
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
})
|
||||
sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version))
|
||||
default:
|
||||
// Should never happen since we already filtered out all other file types.
|
||||
return nil, fmt.Errorf("unknown migration type: %s", base)
|
||||
|
@ -89,19 +124,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
|
|||
return sources, nil
|
||||
}
|
||||
|
||||
func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) {
|
||||
var migrations []*migrate.Migration
|
||||
migrationLookup := make(map[int64]*migrate.Migration)
|
||||
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
|
||||
var migrations []*migration
|
||||
migrationLookup := make(map[int64]*migration)
|
||||
// Add all SQL migrations to the list of migrations.
|
||||
for _, s := range sources.sqlSources {
|
||||
m := &migrate.Migration{
|
||||
Type: migrate.TypeSQL,
|
||||
Fullpath: s.Fullpath,
|
||||
Version: s.Version,
|
||||
SQLParsed: false,
|
||||
for _, source := range sources.sqlSources {
|
||||
m := &migration{
|
||||
Source: source,
|
||||
SQL: nil, // SQL migrations are parsed lazily.
|
||||
}
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[s.Version] = m
|
||||
migrationLookup[source.Version] = m
|
||||
}
|
||||
// If there are no Go files in the filesystem and no registered Go migrations, return early.
|
||||
if len(sources.goSources) == 0 && len(registerd) == 0 {
|
||||
|
@ -127,38 +160,41 @@ func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migra
|
|||
// migrations may not have a corresponding file on disk. Which is fine! We include them
|
||||
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
|
||||
// the SQL migration files.
|
||||
for _, r := range registerd {
|
||||
for version, r := range registerd {
|
||||
var fullpath string
|
||||
if s := sources.lookup(TypeGo, version); s != nil {
|
||||
fullpath = s.Fullpath
|
||||
}
|
||||
// Ensure there are no duplicate versions.
|
||||
if existing, ok := migrationLookup[r.Version]; ok {
|
||||
if existing, ok := migrationLookup[version]; ok {
|
||||
if fullpath == "" {
|
||||
fullpath = "manually registered (no source)"
|
||||
}
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
r.Version,
|
||||
existing,
|
||||
filepath.Base(r.Source),
|
||||
version,
|
||||
existing.Source.Fullpath,
|
||||
fullpath,
|
||||
)
|
||||
}
|
||||
m := &migrate.Migration{
|
||||
Fullpath: r.Source, // May be empty if the migration was registered manually.
|
||||
Version: r.Version,
|
||||
Type: migrate.TypeGo,
|
||||
Go: &migrate.Go{
|
||||
UseTx: r.UseTx,
|
||||
UpFn: r.UpFnContext,
|
||||
UpFnNoTx: r.UpFnNoTxContext,
|
||||
DownFn: r.DownFnContext,
|
||||
DownFnNoTx: r.DownFnNoTxContext,
|
||||
},
|
||||
m := &migration{
|
||||
// Note, the fullpath may be empty if the migration was registered manually.
|
||||
Source: newSource(TypeGo, fullpath, version),
|
||||
Go: r,
|
||||
}
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[r.Version] = m
|
||||
migrationLookup[version] = m
|
||||
}
|
||||
// Sort migrations by version in ascending order.
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
return migrations[i].Source.Version < migrations[j].Source.Version
|
||||
})
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func unregisteredError(unregistered []string) error {
|
||||
const (
|
||||
hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations"
|
||||
)
|
||||
f := "file"
|
||||
if len(unregistered) > 1 {
|
||||
f += "s"
|
||||
|
@ -169,8 +205,9 @@ func unregisteredError(unregistered []string) error {
|
|||
for _, name := range unregistered {
|
||||
b.WriteString("\t" + name + "\n")
|
||||
}
|
||||
hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL)
|
||||
b.WriteString(hint)
|
||||
b.WriteString("\n")
|
||||
b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations")
|
||||
|
||||
return errors.New(b.String())
|
||||
}
|
||||
|
|
|
@ -10,14 +10,14 @@ import (
|
|||
|
||||
func TestCollectFileSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
t.Run("nil_fsys", func(t *testing.T) {
|
||||
sources, err := collectFileSources(nil, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Bool(t, sources != nil, true)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
})
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
t.Run("empty_fsys", func(t *testing.T) {
|
||||
sources, err := collectFileSources(fstest.MapFS{}, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
|
@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) {
|
|||
check.Number(t, len(sources.goSources), 0)
|
||||
expected := fileSources{
|
||||
sqlSources: []Source{
|
||||
{Fullpath: "00001_foo.sql", Version: 1},
|
||||
{Fullpath: "00002_bar.sql", Version: 2},
|
||||
{Fullpath: "00003_baz.sql", Version: 3},
|
||||
{Fullpath: "00110_qux.sql", Version: 110},
|
||||
newSource(TypeSQL, "00001_foo.sql", 1),
|
||||
newSource(TypeSQL, "00002_bar.sql", 2),
|
||||
newSource(TypeSQL, "00003_baz.sql", 3),
|
||||
newSource(TypeSQL, "00110_qux.sql", 110),
|
||||
},
|
||||
}
|
||||
for i := 0; i < len(sources.sqlSources); i++ {
|
||||
|
@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) {
|
|||
check.Number(t, len(sources.goSources), 0)
|
||||
expected := fileSources{
|
||||
sqlSources: []Source{
|
||||
{Fullpath: "00001_foo.sql", Version: 1},
|
||||
{Fullpath: "00003_baz.sql", Version: 3},
|
||||
newSource(TypeSQL, "00001_foo.sql", 1),
|
||||
newSource(TypeSQL, "00003_baz.sql", 3),
|
||||
},
|
||||
}
|
||||
for i := 0; i < len(sources.sqlSources); i++ {
|
||||
|
@ -159,18 +159,146 @@ func TestCollectFileSources(t *testing.T) {
|
|||
}
|
||||
}
|
||||
assertDirpath(".", []Source{
|
||||
{Fullpath: "876_a.sql", Version: 876},
|
||||
newSource(TypeSQL, "876_a.sql", 876),
|
||||
})
|
||||
assertDirpath("dir1", []Source{
|
||||
{Fullpath: "101_a.sql", Version: 101},
|
||||
{Fullpath: "102_b.sql", Version: 102},
|
||||
{Fullpath: "103_c.sql", Version: 103},
|
||||
newSource(TypeSQL, "101_a.sql", 101),
|
||||
newSource(TypeSQL, "102_b.sql", 102),
|
||||
newSource(TypeSQL, "103_c.sql", 103),
|
||||
})
|
||||
assertDirpath("dir2", []Source{
|
||||
newSource(TypeSQL, "201_a.sql", 201),
|
||||
})
|
||||
assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}})
|
||||
assertDirpath("dir3", nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("with_go_files_on_disk", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
// SQL
|
||||
"migrations/00001_foo.sql": sqlMapFile,
|
||||
// Go
|
||||
"migrations/00002_bar.go": {Data: []byte(`package migrations`)},
|
||||
"migrations/00003_baz.go": {Data: []byte(`package migrations`)},
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFileSources(fsys, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, len(sources.sqlSources), 1)
|
||||
check.Equal(t, len(sources.goSources), 2)
|
||||
src1 := sources.lookup(TypeSQL, 1)
|
||||
check.Bool(t, src1 != nil, true)
|
||||
src2 := sources.lookup(TypeGo, 2)
|
||||
check.Bool(t, src2 != nil, true)
|
||||
src3 := sources.lookup(TypeGo, 3)
|
||||
check.Bool(t, src3 != nil, true)
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
2: {version: 2},
|
||||
3: {version: 3},
|
||||
})
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 3)
|
||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
||||
})
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
_, err := merge(sources, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:")
|
||||
check.Contains(t, err.Error(), "00002_bar.go")
|
||||
check.Contains(t, err.Error(), "00003_baz.go")
|
||||
})
|
||||
t.Run("unregistered_some", func(t *testing.T) {
|
||||
_, err := merge(sources, map[int64]*goMigration{
|
||||
2: {version: 2},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
||||
check.Contains(t, err.Error(), "00003_baz.go")
|
||||
})
|
||||
t.Run("duplicate_sql", func(t *testing.T) {
|
||||
_, err := merge(sources, map[int64]*goMigration{
|
||||
1: {version: 1}, // duplicate. SQL already exists.
|
||||
2: {version: 2},
|
||||
3: {version: 3},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||
})
|
||||
})
|
||||
t.Run("no_go_files_on_disk", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
// SQL
|
||||
"migrations/00001_foo.sql": sqlMapFile,
|
||||
"migrations/00002_bar.sql": sqlMapFile,
|
||||
"migrations/00005_baz.sql": sqlMapFile,
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFileSources(fsys, false, nil)
|
||||
check.NoError(t, err)
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
3: {version: 3},
|
||||
// 4 is missing
|
||||
6: {version: 6},
|
||||
})
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 5)
|
||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||
assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5))
|
||||
assertMigration(t, migrations[4], newSource(TypeGo, "", 6))
|
||||
})
|
||||
})
|
||||
t.Run("partial_go_files_on_disk", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"migrations/00001_foo.sql": sqlMapFile,
|
||||
"migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)},
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFileSources(fsys, false, nil)
|
||||
check.NoError(t, err)
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
// This is the only Go file on disk.
|
||||
2: {version: 2},
|
||||
// These are not on disk. Explicitly registered.
|
||||
3: {version: 3},
|
||||
6: {version: 6},
|
||||
})
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 4)
|
||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func assertMigration(t *testing.T, got *migration, want Source) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.Source, want)
|
||||
switch got.Source.Type {
|
||||
case TypeGo:
|
||||
check.Equal(t, got.Go.version, want.Version)
|
||||
case TypeSQL:
|
||||
check.Bool(t, got.SQL == nil, true)
|
||||
default:
|
||||
t.Fatalf("unknown migration type: %s", got.Source.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func newSQLOnlyFS() fstest.MapFS {
|
||||
return fstest.MapFS{
|
||||
"migrations/00001_foo.sql": sqlMapFile,
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlextended"
|
||||
)
|
||||
|
||||
type migration struct {
|
||||
Source Source
|
||||
// A migration is either a Go migration or a SQL migration, but never both.
|
||||
//
|
||||
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
|
||||
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
|
||||
// majority of the time migrations are incremental, so it is likely that the user will only want
|
||||
// to run the last few migrations, and there is no need to parse ALL prior migrations.
|
||||
//
|
||||
// Exactly one of these fields will be set:
|
||||
Go *goMigration
|
||||
// -- OR --
|
||||
SQL *sqlMigration
|
||||
}
|
||||
|
||||
type MigrationType int
|
||||
|
||||
const (
|
||||
TypeGo MigrationType = iota + 1
|
||||
TypeSQL
|
||||
)
|
||||
|
||||
func (t MigrationType) String() string {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
return "go"
|
||||
case TypeSQL:
|
||||
return "sql"
|
||||
default:
|
||||
// This should never happen.
|
||||
return fmt.Sprintf("unknown (%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *migration) GetSQLStatements(direction bool) ([]string, error) {
|
||||
if m.Source.Type != TypeSQL {
|
||||
return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type)
|
||||
}
|
||||
if m.SQL == nil {
|
||||
return nil, errors.New("sql migration has not been parsed")
|
||||
}
|
||||
if direction {
|
||||
return m.SQL.UpStatements, nil
|
||||
}
|
||||
return m.SQL.DownStatements, nil
|
||||
}
|
||||
|
||||
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
var fn func(context.Context, *sql.Tx) error
|
||||
if direction && g.up != nil {
|
||||
fn = g.up.Run
|
||||
}
|
||||
if !direction && g.down != nil {
|
||||
fn = g.down.Run
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
var fn func(context.Context, *sql.DB) error
|
||||
if direction && g.up != nil {
|
||||
fn = g.up.RunNoTx
|
||||
}
|
||||
if !direction && g.down != nil {
|
||||
fn = g.down.RunNoTx
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type sqlMigration struct {
|
||||
UseTx bool
|
||||
UpStatements []string
|
||||
DownStatements []string
|
||||
}
|
||||
|
||||
func (s *sqlMigration) IsEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return len(s.UpStatements) == 0
|
||||
}
|
||||
return len(s.DownStatements) == 0
|
||||
}
|
||||
|
||||
func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
|
||||
var statements []string
|
||||
if direction {
|
||||
statements = s.UpStatements
|
||||
} else {
|
||||
statements = s.DownStatements
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/migrate"
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/sqladapter"
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -17,6 +19,8 @@ var (
|
|||
ErrNoMigrations = errors.New("no migrations found")
|
||||
)
|
||||
|
||||
var registeredGoMigrations = make(map[int64]*goose.Migration)
|
||||
|
||||
// NewProvider returns a new goose Provider.
|
||||
//
|
||||
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
||||
|
@ -68,7 +72,59 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrations, err := merge(sources, nil)
|
||||
//
|
||||
// TODO(mf): move the merging of Go migrations into a separate function.
|
||||
//
|
||||
registered := make(map[int64]*goMigration)
|
||||
// Add user-registered Go migrations.
|
||||
for version, m := range cfg.registered {
|
||||
registered[version] = &goMigration{
|
||||
version: version,
|
||||
up: m.up,
|
||||
down: m.down,
|
||||
}
|
||||
}
|
||||
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
|
||||
// struct to the new goMigration struct.
|
||||
for version, m := range registeredGoMigrations {
|
||||
if _, ok := registered[version]; ok {
|
||||
return nil, fmt.Errorf("go migration with version %d already registered", version)
|
||||
}
|
||||
g := &goMigration{
|
||||
version: version,
|
||||
}
|
||||
if m == nil {
|
||||
return nil, errors.New("registered migration with nil init function")
|
||||
}
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext")
|
||||
}
|
||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
||||
return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext")
|
||||
}
|
||||
// Up
|
||||
if m.UpFnContext != nil {
|
||||
g.up = &GoMigration{
|
||||
Run: m.UpFnContext,
|
||||
}
|
||||
} else if m.UpFnNoTxContext != nil {
|
||||
g.up = &GoMigration{
|
||||
RunNoTx: m.UpFnNoTxContext,
|
||||
}
|
||||
}
|
||||
// Down
|
||||
if m.DownFnContext != nil {
|
||||
g.down = &GoMigration{
|
||||
Run: m.DownFnContext,
|
||||
}
|
||||
} else if m.DownFnNoTxContext != nil {
|
||||
g.down = &GoMigration{
|
||||
RunNoTx: m.DownFnNoTxContext,
|
||||
}
|
||||
}
|
||||
registered[version] = g
|
||||
}
|
||||
migrations, err := merge(sources, registered)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -98,7 +154,7 @@ type Provider struct {
|
|||
fsys fs.FS
|
||||
cfg config
|
||||
store sqladapter.Store
|
||||
migrations []*migrate.Migration
|
||||
migrations []*migration
|
||||
}
|
||||
|
||||
// State represents the state of a migration.
|
||||
|
@ -137,48 +193,12 @@ func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
|||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// SourceType represents the type of migration source.
|
||||
type SourceType string
|
||||
|
||||
const (
|
||||
// SourceTypeSQL represents a SQL migration.
|
||||
SourceTypeSQL SourceType = "sql"
|
||||
// SourceTypeGo represents a Go migration.
|
||||
SourceTypeGo SourceType = "go"
|
||||
)
|
||||
|
||||
// Source represents a single migration source.
|
||||
//
|
||||
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
|
||||
// the migration has a corresponding file on disk. It will be empty if the migration was registered
|
||||
// manually.
|
||||
type Source struct {
|
||||
// Type is the type of migration.
|
||||
Type SourceType
|
||||
// Full path to the migration file.
|
||||
//
|
||||
// Example: /path/to/migrations/001_create_users_table.sql
|
||||
Fullpath string
|
||||
// Version is the version of the migration.
|
||||
Version int64
|
||||
}
|
||||
|
||||
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
|
||||
// ascending order by version.
|
||||
func (p *Provider) ListSources() []*Source {
|
||||
sources := make([]*Source, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
s := &Source{
|
||||
Fullpath: m.Fullpath,
|
||||
Version: m.Version,
|
||||
}
|
||||
switch m.Type {
|
||||
case migrate.TypeSQL:
|
||||
s.Type = SourceTypeSQL
|
||||
case migrate.TypeGo:
|
||||
s.Type = SourceTypeGo
|
||||
}
|
||||
sources = append(sources, s)
|
||||
sources = append(sources, &m.Source)
|
||||
}
|
||||
return sources
|
||||
}
|
||||
|
@ -240,3 +260,25 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
|||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
|
||||
// will not be parsed again.
|
||||
//
|
||||
// Important: This function will mutate SQL migrations and is not safe for concurrent use.
|
||||
func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
|
||||
for _, m := range migrations {
|
||||
// If the migration is a SQL migration, and it has not been parsed, parse it.
|
||||
if m.Source.Type == TypeSQL && m.SQL == nil {
|
||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.SQL = &sqlMigration{
|
||||
UseTx: parsed.UseTx,
|
||||
UpStatements: parsed.Up,
|
||||
DownStatements: parsed.Down,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -72,11 +74,67 @@ func WithExcludes(excludes []string) ProviderOption {
|
|||
})
|
||||
}
|
||||
|
||||
// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration].
|
||||
type GoMigration struct {
|
||||
// One of the following must be set:
|
||||
Run func(context.Context, *sql.Tx) error
|
||||
// -- OR --
|
||||
RunNoTx func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
// WithGoMigration registers a Go migration with the given version.
|
||||
//
|
||||
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
|
||||
// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set.
|
||||
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
if version < 1 {
|
||||
return fmt.Errorf("go migration version must be greater than 0")
|
||||
}
|
||||
if _, ok := c.registered[version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", version)
|
||||
}
|
||||
// Allow nil up/down functions. This enables users to apply "no-op" migrations, while
|
||||
// versioning them.
|
||||
if up != nil {
|
||||
if up.Run == nil && up.RunNoTx == nil {
|
||||
return fmt.Errorf("go migration with version %d must have an up function", version)
|
||||
}
|
||||
if up.Run != nil && up.RunNoTx != nil {
|
||||
return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version)
|
||||
}
|
||||
}
|
||||
if down != nil {
|
||||
if down.Run == nil && down.RunNoTx == nil {
|
||||
return fmt.Errorf("go migration with version %d must have a down function", version)
|
||||
}
|
||||
if down.Run != nil && down.RunNoTx != nil {
|
||||
return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version)
|
||||
}
|
||||
}
|
||||
c.registered[version] = &goMigration{
|
||||
version: version,
|
||||
up: up,
|
||||
down: down,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type goMigration struct {
|
||||
version int64
|
||||
up, down *GoMigration
|
||||
}
|
||||
|
||||
type config struct {
|
||||
tableName string
|
||||
verbose bool
|
||||
excludes map[string]bool
|
||||
|
||||
// Go migrations registered by the user. These will be merged/resolved with migrations from the
|
||||
// filesystem and init() functions.
|
||||
registered map[int64]*goMigration
|
||||
|
||||
// Locking options
|
||||
lockEnabled bool
|
||||
sessionLocker lock.SessionLocker
|
||||
|
|
|
@ -36,11 +36,11 @@ func TestProvider(t *testing.T) {
|
|||
// 1
|
||||
check.Equal(t, sources[0].Version, int64(1))
|
||||
check.Equal(t, sources[0].Fullpath, "001_foo.sql")
|
||||
check.Equal(t, sources[0].Type, provider.SourceTypeSQL)
|
||||
check.Equal(t, sources[0].Type, provider.TypeSQL)
|
||||
// 2
|
||||
check.Equal(t, sources[1].Version, int64(2))
|
||||
check.Equal(t, sources[1].Fullpath, "002_bar.sql")
|
||||
check.Equal(t, sources[1].Type, provider.SourceTypeSQL)
|
||||
check.Equal(t, sources[1].Type, provider.TypeSQL)
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package migrate
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -8,10 +8,10 @@ import (
|
|||
)
|
||||
|
||||
// Run runs the migration inside of a transaction.
|
||||
func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
switch m.Type {
|
||||
func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("tx: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, tx, direction)
|
||||
|
@ -19,14 +19,14 @@ func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
|||
return m.Go.run(ctx, tx, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
}
|
||||
|
||||
// RunNoTx runs the migration without a transaction.
|
||||
func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
switch m.Type {
|
||||
func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("db: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, db, direction)
|
||||
|
@ -34,14 +34,14 @@ func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) err
|
|||
return m.Go.runNoTx(ctx, db, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
}
|
||||
|
||||
// RunConn runs the migration without a transaction using the provided connection.
|
||||
func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
|
||||
switch m.Type {
|
||||
func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("conn: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, conn, direction)
|
||||
|
@ -49,5 +49,5 @@ func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool)
|
|||
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package sqlparser
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
||||
"go.uber.org/multierr"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type ParsedSQL struct {
|
||||
UseTx bool
|
||||
Up, Down []string
|
||||
}
|
||||
|
||||
func ParseAllFromFS(fsys fs.FS, filename string, debug bool) (*ParsedSQL, error) {
|
||||
parsedSQL := new(ParsedSQL)
|
||||
// TODO(mf): parse is called twice, once for up and once for down. This is inefficient. It
|
||||
// should be possible to parse both directions in one pass. Also, UseTx is set once (but
|
||||
// returned twice), which is unnecessary and potentially error-prone if the two calls to
|
||||
// parseSQL disagree based on direction.
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
up, useTx, err := parse(fsys, filename, DirectionUp, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsedSQL.Up = up
|
||||
parsedSQL.UseTx = useTx
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
down, _, err := parse(fsys, filename, DirectionDown, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parsedSQL.Down = down
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsedSQL, nil
|
||||
}
|
||||
|
||||
func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []string, _ bool, retErr error) {
|
||||
r, err := fsys.Open(filename)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, r.Close())
|
||||
}()
|
||||
return ParseSQLMigration(r, direction, debug)
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package sqlparser_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
)
|
||||
|
||||
func TestParseAllFromFS(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("file_not_exist", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{}
|
||||
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, os.ErrNotExist), true)
|
||||
})
|
||||
t.Run("empty_file", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"001_foo.sql": &fstest.MapFile{},
|
||||
}
|
||||
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "failed to parse migration")
|
||||
check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation")
|
||||
})
|
||||
t.Run("all_statements", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"001_foo.sql": newFile(`
|
||||
-- +goose Up
|
||||
`),
|
||||
"002_bar.sql": newFile(`
|
||||
-- +goose Up
|
||||
-- +goose Down
|
||||
`),
|
||||
"003_baz.sql": newFile(`
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id int);
|
||||
CREATE TABLE bar (id int);
|
||||
|
||||
-- +goose Down
|
||||
DROP TABLE bar;
|
||||
`),
|
||||
"004_qux.sql": newFile(`
|
||||
-- +goose NO TRANSACTION
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id int);
|
||||
-- +goose Down
|
||||
DROP TABLE foo;
|
||||
`),
|
||||
}
|
||||
parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||
check.NoError(t, err)
|
||||
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false)
|
||||
check.NoError(t, err)
|
||||
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false)
|
||||
check.NoError(t, err)
|
||||
assertParsedSQL(t, parsedSQL, true, 2, 1)
|
||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false)
|
||||
check.NoError(t, err)
|
||||
assertParsedSQL(t, parsedSQL, false, 1, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) {
|
||||
t.Helper()
|
||||
check.Bool(t, got != nil, true)
|
||||
check.Equal(t, len(got.Up), up)
|
||||
check.Equal(t, len(got.Down), down)
|
||||
check.Equal(t, got.UseTx, useTx)
|
||||
}
|
||||
|
||||
func newFile(data string) *fstest.MapFile {
|
||||
return &fstest.MapFile{
|
||||
Data: []byte(data),
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue