feat(experimental): shuffle packages & add explicit provider Go func registration (#616)

pull/617/head
Michael Fridman 2023-10-14 23:04:07 -04:00 committed by GitHub
parent 68853f91ea
commit 58f8534610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 629 additions and 357 deletions

1
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/vertica/vertica-sql-go v1.3.3
github.com/ziutek/mymysql v1.5.4
go.uber.org/multierr v1.11.0
golang.org/x/sync v0.4.0
modernc.org/sqlite v1.26.0
)

3
go.sum
View File

@ -184,7 +184,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -1,9 +0,0 @@
// Package migrate defines a Migration struct and implements the migration logic for executing Go
// and SQL migrations.
//
// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported.
// - For SQL migrations, all three are supported.
//
// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first
// time it is executed.
package migrate

View File

@ -1,166 +0,0 @@
package migrate
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3/internal/sqlextended"
)
type Migration struct {
// Fullpath is the full path to the migration file.
//
// Example: /path/to/migrations/123_create_users_table.go
Fullpath string
// Version is the version of the migration.
Version int64
// Type is the type of migration.
Type MigrationType
// A migration is either a Go migration or a SQL migration, but never both.
//
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
// majority of the time migrations are incremental, so it is likely that the user will only want
// to run the last few migrations, and there is no need to parse ALL prior migrations.
//
// Exactly one of these fields will be set:
Go *Go
// -- or --
SQLParsed bool
SQL *SQL
}
type MigrationType int
const (
TypeGo MigrationType = iota + 1
TypeSQL
)
func (t MigrationType) String() string {
switch t {
case TypeGo:
return "go"
case TypeSQL:
return "sql"
default:
// This should never happen.
return "unknown"
}
}
func (m *Migration) UseTx() bool {
switch m.Type {
case TypeGo:
return m.Go.UseTx
case TypeSQL:
return m.SQL.UseTx
default:
// This should never happen.
panic("unknown migration type: use tx")
}
}
func (m *Migration) IsEmpty(direction bool) bool {
switch m.Type {
case TypeGo:
return m.Go.IsEmpty(direction)
case TypeSQL:
return m.SQL.IsEmpty(direction)
default:
// This should never happen.
panic("unknown migration type: is empty")
}
}
func (m *Migration) GetSQLStatements(direction bool) ([]string, error) {
if m.Type != TypeSQL {
return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type)
}
if m.SQL == nil {
return nil, errors.New("sql migration has not been initialized")
}
if !m.SQLParsed {
return nil, errors.New("sql migration has not been parsed")
}
if direction {
return m.SQL.UpStatements, nil
}
return m.SQL.DownStatements, nil
}
type Go struct {
// We used an explicit bool instead of relying on a pointer because registered funcs may be nil.
// These are still valid Go and versioned migrations, but they are just empty.
//
// For example: goose.AddMigration(nil, nil)
UseTx bool
// Only one of these func pairs will be set:
UpFn, DownFn func(context.Context, *sql.Tx) error
// -- or --
UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error
}
func (g *Go) IsEmpty(direction bool) bool {
if direction {
return g.UpFn == nil && g.UpFnNoTx == nil
}
return g.DownFn == nil && g.DownFnNoTx == nil
}
func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error {
var fn func(context.Context, *sql.Tx) error
if direction {
fn = g.UpFn
} else {
fn = g.DownFn
}
if fn != nil {
return fn(ctx, tx)
}
return nil
}
func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
var fn func(context.Context, *sql.DB) error
if direction {
fn = g.UpFnNoTx
} else {
fn = g.DownFnNoTx
}
if fn != nil {
return fn(ctx, db)
}
return nil
}
type SQL struct {
UseTx bool
UpStatements []string
DownStatements []string
}
func (s *SQL) IsEmpty(direction bool) bool {
if direction {
return len(s.UpStatements) == 0
}
return len(s.DownStatements) == 0
}
func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
} else {
statements = s.DownStatements
}
for _, stmt := range statements {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}
}
return nil
}

View File

@ -1,75 +0,0 @@
package migrate
import (
"bytes"
"io"
"io/fs"
"github.com/pressly/goose/v3/internal/sqlparser"
)
// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
// will not be parsed again.
//
// Important: This function will mutate SQL migrations.
func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error {
for _, m := range migrations {
if m.Type == TypeSQL && !m.SQLParsed {
parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug)
if err != nil {
return err
}
m.SQLParsed = true
m.SQL = parsedSQLMigration
}
}
return nil
}
// parse is used to determine which direction to parse the SQL migration.
type parse int
const (
// parseAll parses all SQL statements in BOTH directions.
parseAll parse = iota + 1
// parseUp parses all SQL statements in the UP direction.
parseUp
// parseDown parses all SQL statements in the DOWN direction.
parseDown
)
func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) {
r, err := fsys.Open(filename)
if err != nil {
return nil, err
}
by, err := io.ReadAll(r)
if err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
s := new(SQL)
if p == parseAll || p == parseUp {
s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionUp,
debug,
)
if err != nil {
return nil, err
}
}
if p == parseAll || p == parseDown {
s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionDown,
debug,
)
if err != nil {
return nil, err
}
}
return s, nil
}

View File

@ -9,15 +9,56 @@ import (
"strings"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/migrate"
)
// Source represents a single migration source.
//
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
// the migration has a corresponding file on disk. It will be empty if the migration was registered
// manually.
type Source struct {
// Type is the type of migration.
Type MigrationType
// Full path to the migration file.
//
// Example: /path/to/migrations/001_create_users_table.sql
Fullpath string
// Version is the version of the migration.
Version int64
}
func newSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Fullpath: fullpath,
Version: version,
}
}
// fileSources represents a collection of migration files on the filesystem.
type fileSources struct {
sqlSources []Source
goSources []Source
}
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
switch t {
case TypeGo:
for _, source := range s.goSources {
if source.Version == version {
return &source
}
}
case TypeSQL:
for _, source := range s.sqlSources {
if source.Version == version {
return &source
}
}
}
return nil
}
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
// in which case an empty fileSources is returned.
@ -69,15 +110,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, Source{
Fullpath: fullpath,
Version: version,
})
sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version))
case ".go":
sources.goSources = append(sources.goSources, Source{
Fullpath: fullpath,
Version: version,
})
sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version))
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
@ -89,19 +124,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
return sources, nil
}
func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) {
var migrations []*migrate.Migration
migrationLookup := make(map[int64]*migrate.Migration)
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
var migrations []*migration
migrationLookup := make(map[int64]*migration)
// Add all SQL migrations to the list of migrations.
for _, s := range sources.sqlSources {
m := &migrate.Migration{
Type: migrate.TypeSQL,
Fullpath: s.Fullpath,
Version: s.Version,
SQLParsed: false,
for _, source := range sources.sqlSources {
m := &migration{
Source: source,
SQL: nil, // SQL migrations are parsed lazily.
}
migrations = append(migrations, m)
migrationLookup[s.Version] = m
migrationLookup[source.Version] = m
}
// If there are no Go files in the filesystem and no registered Go migrations, return early.
if len(sources.goSources) == 0 && len(registerd) == 0 {
@ -127,38 +160,41 @@ func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migra
// migrations may not have a corresponding file on disk. Which is fine! We include them
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
// the SQL migration files.
for _, r := range registerd {
for version, r := range registerd {
var fullpath string
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath
}
// Ensure there are no duplicate versions.
if existing, ok := migrationLookup[r.Version]; ok {
if existing, ok := migrationLookup[version]; ok {
if fullpath == "" {
fullpath = "manually registered (no source)"
}
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
r.Version,
existing,
filepath.Base(r.Source),
version,
existing.Source.Fullpath,
fullpath,
)
}
m := &migrate.Migration{
Fullpath: r.Source, // May be empty if the migration was registered manually.
Version: r.Version,
Type: migrate.TypeGo,
Go: &migrate.Go{
UseTx: r.UseTx,
UpFn: r.UpFnContext,
UpFnNoTx: r.UpFnNoTxContext,
DownFn: r.DownFnContext,
DownFnNoTx: r.DownFnNoTxContext,
},
m := &migration{
// Note, the fullpath may be empty if the migration was registered manually.
Source: newSource(TypeGo, fullpath, version),
Go: r,
}
migrations = append(migrations, m)
migrationLookup[r.Version] = m
migrationLookup[version] = m
}
// Sort migrations by version in ascending order.
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version < migrations[j].Version
return migrations[i].Source.Version < migrations[j].Source.Version
})
return migrations, nil
}
func unregisteredError(unregistered []string) error {
const (
hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations"
)
f := "file"
if len(unregistered) > 1 {
f += "s"
@ -169,8 +205,9 @@ func unregisteredError(unregistered []string) error {
for _, name := range unregistered {
b.WriteString("\t" + name + "\n")
}
hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL)
b.WriteString(hint)
b.WriteString("\n")
b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations")
return errors.New(b.String())
}

View File

@ -10,14 +10,14 @@ import (
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil", func(t *testing.T) {
t.Run("nil_fsys", func(t *testing.T) {
sources, err := collectFileSources(nil, false, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty", func(t *testing.T) {
t.Run("empty_fsys", func(t *testing.T) {
sources, err := collectFileSources(fstest.MapFS{}, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
{Fullpath: "00001_foo.sql", Version: 1},
{Fullpath: "00002_bar.sql", Version: 2},
{Fullpath: "00003_baz.sql", Version: 3},
{Fullpath: "00110_qux.sql", Version: 110},
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00002_bar.sql", 2),
newSource(TypeSQL, "00003_baz.sql", 3),
newSource(TypeSQL, "00110_qux.sql", 110),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
{Fullpath: "00001_foo.sql", Version: 1},
{Fullpath: "00003_baz.sql", Version: 3},
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00003_baz.sql", 3),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
@ -159,18 +159,146 @@ func TestCollectFileSources(t *testing.T) {
}
}
assertDirpath(".", []Source{
{Fullpath: "876_a.sql", Version: 876},
newSource(TypeSQL, "876_a.sql", 876),
})
assertDirpath("dir1", []Source{
{Fullpath: "101_a.sql", Version: 101},
{Fullpath: "102_b.sql", Version: 102},
{Fullpath: "103_c.sql", Version: 103},
newSource(TypeSQL, "101_a.sql", 101),
newSource(TypeSQL, "102_b.sql", 102),
newSource(TypeSQL, "103_c.sql", 103),
})
assertDirpath("dir2", []Source{
newSource(TypeSQL, "201_a.sql", 201),
})
assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}})
assertDirpath("dir3", nil)
})
}
func TestMerge(t *testing.T) {
t.Parallel()
t.Run("with_go_files_on_disk", func(t *testing.T) {
mapFS := fstest.MapFS{
// SQL
"migrations/00001_foo.sql": sqlMapFile,
// Go
"migrations/00002_bar.go": {Data: []byte(`package migrations`)},
"migrations/00003_baz.go": {Data: []byte(`package migrations`)},
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 1)
check.Equal(t, len(sources.goSources), 2)
src1 := sources.lookup(TypeSQL, 1)
check.Bool(t, src1 != nil, true)
src2 := sources.lookup(TypeGo, 2)
check.Bool(t, src2 != nil, true)
src3 := sources.lookup(TypeGo, 3)
check.Bool(t, src3 != nil, true)
t.Run("valid", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
2: {version: 2},
3: {version: 3},
})
check.NoError(t, err)
check.Number(t, len(migrations), 3)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
})
t.Run("unregistered_all", func(t *testing.T) {
_, err := merge(sources, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:")
check.Contains(t, err.Error(), "00002_bar.go")
check.Contains(t, err.Error(), "00003_baz.go")
})
t.Run("unregistered_some", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
2: {version: 2},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
check.Contains(t, err.Error(), "00003_baz.go")
})
t.Run("duplicate_sql", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
1: {version: 1}, // duplicate. SQL already exists.
2: {version: 2},
3: {version: 3},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
})
t.Run("no_go_files_on_disk", func(t *testing.T) {
mapFS := fstest.MapFS{
// SQL
"migrations/00001_foo.sql": sqlMapFile,
"migrations/00002_bar.sql": sqlMapFile,
"migrations/00005_baz.sql": sqlMapFile,
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
3: {version: 3},
// 4 is missing
6: {version: 6},
})
check.NoError(t, err)
check.Number(t, len(migrations), 5)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5))
assertMigration(t, migrations[4], newSource(TypeGo, "", 6))
})
})
t.Run("partial_go_files_on_disk", func(t *testing.T) {
mapFS := fstest.MapFS{
"migrations/00001_foo.sql": sqlMapFile,
"migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)},
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
// This is the only Go file on disk.
2: {version: 2},
// These are not on disk. Explicitly registered.
3: {version: 3},
6: {version: 6},
})
check.NoError(t, err)
check.Number(t, len(migrations), 4)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
})
})
}
func assertMigration(t *testing.T, got *migration, want Source) {
t.Helper()
check.Equal(t, got.Source, want)
switch got.Source.Type {
case TypeGo:
check.Equal(t, got.Go.version, want.Version)
case TypeSQL:
check.Bool(t, got.SQL == nil, true)
default:
t.Fatalf("unknown migration type: %s", got.Source.Type)
}
}
func newSQLOnlyFS() fstest.MapFS {
return fstest.MapFS{
"migrations/00001_foo.sql": sqlMapFile,

View File

@ -0,0 +1,119 @@
package provider
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3/internal/sqlextended"
)
type migration struct {
Source Source
// A migration is either a Go migration or a SQL migration, but never both.
//
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
// majority of the time migrations are incremental, so it is likely that the user will only want
// to run the last few migrations, and there is no need to parse ALL prior migrations.
//
// Exactly one of these fields will be set:
Go *goMigration
// -- OR --
SQL *sqlMigration
}
type MigrationType int
const (
TypeGo MigrationType = iota + 1
TypeSQL
)
func (t MigrationType) String() string {
switch t {
case TypeGo:
return "go"
case TypeSQL:
return "sql"
default:
// This should never happen.
return fmt.Sprintf("unknown (%d)", t)
}
}
func (m *migration) GetSQLStatements(direction bool) ([]string, error) {
if m.Source.Type != TypeSQL {
return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type)
}
if m.SQL == nil {
return nil, errors.New("sql migration has not been parsed")
}
if direction {
return m.SQL.UpStatements, nil
}
return m.SQL.DownStatements, nil
}
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.Tx) error
if direction && g.up != nil {
fn = g.up.Run
}
if !direction && g.down != nil {
fn = g.down.Run
}
if fn != nil {
return fn(ctx, tx)
}
return nil
}
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.DB) error
if direction && g.up != nil {
fn = g.up.RunNoTx
}
if !direction && g.down != nil {
fn = g.down.RunNoTx
}
if fn != nil {
return fn(ctx, db)
}
return nil
}
type sqlMigration struct {
UseTx bool
UpStatements []string
DownStatements []string
}
func (s *sqlMigration) IsEmpty(direction bool) bool {
if direction {
return len(s.UpStatements) == 0
}
return len(s.DownStatements) == 0
}
func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
} else {
statements = s.DownStatements
}
for _, stmt := range statements {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}
}
return nil
}

View File

@ -4,12 +4,14 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"os"
"time"
"github.com/pressly/goose/v3/internal/migrate"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/internal/sqlparser"
)
var (
@ -17,6 +19,8 @@ var (
ErrNoMigrations = errors.New("no migrations found")
)
var registeredGoMigrations = make(map[int64]*goose.Migration)
// NewProvider returns a new goose Provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
@ -68,7 +72,59 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
if err != nil {
return nil, err
}
migrations, err := merge(sources, nil)
//
// TODO(mf): move the merging of Go migrations into a separate function.
//
registered := make(map[int64]*goMigration)
// Add user-registered Go migrations.
for version, m := range cfg.registered {
registered[version] = &goMigration{
version: version,
up: m.up,
down: m.down,
}
}
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
// struct to the new goMigration struct.
for version, m := range registeredGoMigrations {
if _, ok := registered[version]; ok {
return nil, fmt.Errorf("go migration with version %d already registered", version)
}
g := &goMigration{
version: version,
}
if m == nil {
return nil, errors.New("registered migration with nil init function")
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext")
}
// Up
if m.UpFnContext != nil {
g.up = &GoMigration{
Run: m.UpFnContext,
}
} else if m.UpFnNoTxContext != nil {
g.up = &GoMigration{
RunNoTx: m.UpFnNoTxContext,
}
}
// Down
if m.DownFnContext != nil {
g.down = &GoMigration{
Run: m.DownFnContext,
}
} else if m.DownFnNoTxContext != nil {
g.down = &GoMigration{
RunNoTx: m.DownFnNoTxContext,
}
}
registered[version] = g
}
migrations, err := merge(sources, registered)
if err != nil {
return nil, err
}
@ -98,7 +154,7 @@ type Provider struct {
fsys fs.FS
cfg config
store sqladapter.Store
migrations []*migrate.Migration
migrations []*migration
}
// State represents the state of a migration.
@ -137,48 +193,12 @@ func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
// SourceType represents the type of migration source.
type SourceType string
const (
// SourceTypeSQL represents a SQL migration.
SourceTypeSQL SourceType = "sql"
// SourceTypeGo represents a Go migration.
SourceTypeGo SourceType = "go"
)
// Source represents a single migration source.
//
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
// the migration has a corresponding file on disk. It will be empty if the migration was registered
// manually.
type Source struct {
// Type is the type of migration.
Type SourceType
// Full path to the migration file.
//
// Example: /path/to/migrations/001_create_users_table.sql
Fullpath string
// Version is the version of the migration.
Version int64
}
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
// ascending order by version.
func (p *Provider) ListSources() []*Source {
sources := make([]*Source, 0, len(p.migrations))
for _, m := range p.migrations {
s := &Source{
Fullpath: m.Fullpath,
Version: m.Version,
}
switch m.Type {
case migrate.TypeSQL:
s.Type = SourceTypeSQL
case migrate.TypeGo:
s.Type = SourceTypeGo
}
sources = append(sources, s)
sources = append(sources, &m.Source)
}
return sources
}
@ -240,3 +260,25 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
return nil, errors.New("not implemented")
}
// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
// will not be parsed again.
//
// Important: This function will mutate SQL migrations and is not safe for concurrent use.
func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
for _, m := range migrations {
// If the migration is a SQL migration, and it has not been parsed, parse it.
if m.Source.Type == TypeSQL && m.SQL == nil {
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug)
if err != nil {
return err
}
m.SQL = &sqlMigration{
UseTx: parsed.UseTx,
UpStatements: parsed.Up,
DownStatements: parsed.Down,
}
}
}
return nil
}

View File

@ -1,6 +1,8 @@
package provider
import (
"context"
"database/sql"
"errors"
"fmt"
@ -72,11 +74,67 @@ func WithExcludes(excludes []string) ProviderOption {
})
}
// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration].
type GoMigration struct {
// One of the following must be set:
Run func(context.Context, *sql.Tx) error
// -- OR --
RunNoTx func(context.Context, *sql.DB) error
}
// WithGoMigration registers a Go migration with the given version.
//
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set.
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
return configFunc(func(c *config) error {
if version < 1 {
return fmt.Errorf("go migration version must be greater than 0")
}
if _, ok := c.registered[version]; ok {
return fmt.Errorf("go migration with version %d already registered", version)
}
// Allow nil up/down functions. This enables users to apply "no-op" migrations, while
// versioning them.
if up != nil {
if up.Run == nil && up.RunNoTx == nil {
return fmt.Errorf("go migration with version %d must have an up function", version)
}
if up.Run != nil && up.RunNoTx != nil {
return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version)
}
}
if down != nil {
if down.Run == nil && down.RunNoTx == nil {
return fmt.Errorf("go migration with version %d must have a down function", version)
}
if down.Run != nil && down.RunNoTx != nil {
return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version)
}
}
c.registered[version] = &goMigration{
version: version,
up: up,
down: down,
}
return nil
})
}
type goMigration struct {
version int64
up, down *GoMigration
}
type config struct {
tableName string
verbose bool
excludes map[string]bool
// Go migrations registered by the user. These will be merged/resolved with migrations from the
// filesystem and init() functions.
registered map[int64]*goMigration
// Locking options
lockEnabled bool
sessionLocker lock.SessionLocker

View File

@ -36,11 +36,11 @@ func TestProvider(t *testing.T) {
// 1
check.Equal(t, sources[0].Version, int64(1))
check.Equal(t, sources[0].Fullpath, "001_foo.sql")
check.Equal(t, sources[0].Type, provider.SourceTypeSQL)
check.Equal(t, sources[0].Type, provider.TypeSQL)
// 2
check.Equal(t, sources[1].Version, int64(2))
check.Equal(t, sources[1].Fullpath, "002_bar.sql")
check.Equal(t, sources[1].Type, provider.SourceTypeSQL)
check.Equal(t, sources[1].Type, provider.TypeSQL)
}
var (

View File

@ -1,4 +1,4 @@
package migrate
package provider
import (
"context"
@ -8,10 +8,10 @@ import (
)
// Run runs the migration inside of a transaction.
func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
switch m.Type {
func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil || !m.SQLParsed {
if m.SQL == nil {
return fmt.Errorf("tx: sql migration has not been parsed")
}
return m.SQL.run(ctx, tx, direction)
@ -19,14 +19,14 @@ func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
return m.Go.run(ctx, tx, direction)
}
// This should never happen.
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
}
// RunNoTx runs the migration without a transaction.
func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
switch m.Type {
func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil || !m.SQLParsed {
if m.SQL == nil {
return fmt.Errorf("db: sql migration has not been parsed")
}
return m.SQL.run(ctx, db, direction)
@ -34,14 +34,14 @@ func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) err
return m.Go.runNoTx(ctx, db, direction)
}
// This should never happen.
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
}
// RunConn runs the migration without a transaction using the provided connection.
func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
switch m.Type {
func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil || !m.SQLParsed {
if m.SQL == nil {
return fmt.Errorf("conn: sql migration has not been parsed")
}
return m.SQL.run(ctx, conn, direction)
@ -49,5 +49,5 @@ func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool)
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
}
// This should never happen.
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath))
}

View File

@ -0,0 +1,54 @@
package sqlparser
import (
"io/fs"
"go.uber.org/multierr"
"golang.org/x/sync/errgroup"
)
type ParsedSQL struct {
UseTx bool
Up, Down []string
}
func ParseAllFromFS(fsys fs.FS, filename string, debug bool) (*ParsedSQL, error) {
parsedSQL := new(ParsedSQL)
// TODO(mf): parse is called twice, once for up and once for down. This is inefficient. It
// should be possible to parse both directions in one pass. Also, UseTx is set once (but
// returned twice), which is unnecessary and potentially error-prone if the two calls to
// parseSQL disagree based on direction.
var g errgroup.Group
g.Go(func() error {
up, useTx, err := parse(fsys, filename, DirectionUp, debug)
if err != nil {
return err
}
parsedSQL.Up = up
parsedSQL.UseTx = useTx
return nil
})
g.Go(func() error {
down, _, err := parse(fsys, filename, DirectionDown, debug)
if err != nil {
return err
}
parsedSQL.Down = down
return nil
})
if err := g.Wait(); err != nil {
return nil, err
}
return parsedSQL, nil
}
func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []string, _ bool, retErr error) {
r, err := fsys.Open(filename)
if err != nil {
return nil, false, err
}
defer func() {
retErr = multierr.Append(retErr, r.Close())
}()
return ParseSQLMigration(r, direction, debug)
}

View File

@ -0,0 +1,82 @@
package sqlparser_test
import (
"errors"
"os"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqlparser"
)
func TestParseAllFromFS(t *testing.T) {
t.Parallel()
t.Run("file_not_exist", func(t *testing.T) {
mapFS := fstest.MapFS{}
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
check.HasError(t, err)
check.Bool(t, errors.Is(err, os.ErrNotExist), true)
})
t.Run("empty_file", func(t *testing.T) {
mapFS := fstest.MapFS{
"001_foo.sql": &fstest.MapFile{},
}
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
check.HasError(t, err)
check.Contains(t, err.Error(), "failed to parse migration")
check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation")
})
t.Run("all_statements", func(t *testing.T) {
mapFS := fstest.MapFS{
"001_foo.sql": newFile(`
-- +goose Up
`),
"002_bar.sql": newFile(`
-- +goose Up
-- +goose Down
`),
"003_baz.sql": newFile(`
-- +goose Up
CREATE TABLE foo (id int);
CREATE TABLE bar (id int);
-- +goose Down
DROP TABLE bar;
`),
"004_qux.sql": newFile(`
-- +goose NO TRANSACTION
-- +goose Up
CREATE TABLE foo (id int);
-- +goose Down
DROP TABLE foo;
`),
}
parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
check.NoError(t, err)
assertParsedSQL(t, parsedSQL, true, 0, 0)
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false)
check.NoError(t, err)
assertParsedSQL(t, parsedSQL, true, 0, 0)
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false)
check.NoError(t, err)
assertParsedSQL(t, parsedSQL, true, 2, 1)
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false)
check.NoError(t, err)
assertParsedSQL(t, parsedSQL, false, 1, 1)
})
}
func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) {
t.Helper()
check.Bool(t, got != nil, true)
check.Equal(t, len(got.Up), up)
check.Equal(t, len(got.Down), down)
check.Equal(t, got.UseTx, useTx)
}
func newFile(data string) *fstest.MapFile {
return &fstest.MapFile{
Data: []byte(data),
}
}