feat(experimental): prefactor provider and cleanup (#626)

pull/628/head
Michael Fridman 2023-10-29 21:55:41 -04:00 committed by GitHub
parent 20a99fa243
commit 497acb407f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 236 additions and 227 deletions

View File

@ -1,8 +0,0 @@
FROM golang:1.17-buster@sha256:3e663ba6af8281b04975b0a34a14d538cdd7d284213f83f05aaf596b80a8c725 as builder
COPY . /src
WORKDIR /src
RUN CGO_ENABLED=0 make dist
FROM scratch AS exporter
COPY --from=builder /src/bin/ /

View File

@ -8,7 +8,7 @@ import (
"strconv"
)
// Deprecated: VERSION will no longer be supported in v4.
// Deprecated: VERSION will no longer be supported in the next major release.
const VERSION = "v3.2.0"
var (

View File

@ -7,17 +7,10 @@ import (
"os"
"path/filepath"
"sort"
"strconv"
"strings"
)
func NewSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Path: fullpath,
Version: version,
}
}
"github.com/pressly/goose/v3"
)
// fileSources represents a collection of migration files on the filesystem.
type fileSources struct {
@ -44,16 +37,16 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source {
return nil
}
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
// in which case an empty fileSources is returned.
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
// be nil, in which case an empty fileSources is returned.
//
// If strict is true, then any error parsing the numeric component of the filename will result in an
// error. The file is skipped otherwise.
//
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
// migration sources from the filesystem.
func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
}
@ -78,7 +71,7 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
version, err := goose.NumericComponent(base)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
@ -95,9 +88,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version))
sources.sqlSources = append(sources.sqlSources, Source{
Type: TypeSQL,
Path: fullpath,
Version: version,
})
case ".go":
sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version))
sources.goSources = append(sources.goSources, Source{
Type: TypeGo,
Path: fullpath,
Version: version,
})
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
@ -165,9 +166,12 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
)
}
m := &migration{
// Note, the fullpath may be empty if the migration was registered manually.
Source: NewSource(TypeGo, fullpath, version),
Go: r,
Source: Source{
Type: TypeGo,
Path: fullpath, // May be empty if migration was registered manually.
Version: version,
},
Go: r,
}
migrations = append(migrations, m)
migrationLookup[version] = m
@ -207,26 +211,3 @@ var _ fs.FS = noopFS{}
func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}
// NumericComponent parses the version from the migration file name.
//
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
// migration, either .sql or .go.
func NumericComponent(filename string) (int64, error) {
base := filepath.Base(filename)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("migration file does not have .sql or .go file extension")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no filename separator '_' found")
}
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
}
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
}
return n, nil
}

View File

@ -12,14 +12,21 @@ import (
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil_fsys", func(t *testing.T) {
sources, err := collectFileSources(nil, false, nil)
sources, err := collectFilesystemSources(nil, false, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("noop_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(noopFS{}, false, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty_fsys", func(t *testing.T) {
sources, err := collectFileSources(fstest.MapFS{}, false, nil)
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
@ -30,28 +37,28 @@ func TestCollectFileSources(t *testing.T) {
"00000_foo.sql": sqlMapFile,
}
// strict disable - should not error
sources, err := collectFileSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
// strict enabled - should error
_, err = collectFileSources(mapFS, true, nil)
_, err = collectFilesystemSources(mapFS, true, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration version must be greater than zero")
})
t.Run("collect", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
NewSource(TypeSQL, "00001_foo.sql", 1),
NewSource(TypeSQL, "00002_bar.sql", 2),
NewSource(TypeSQL, "00003_baz.sql", 3),
NewSource(TypeSQL, "00110_qux.sql", 110),
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00002_bar.sql", 2),
newSource(TypeSQL, "00003_baz.sql", 3),
newSource(TypeSQL, "00110_qux.sql", 110),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
@ -61,7 +68,7 @@ func TestCollectFileSources(t *testing.T) {
t.Run("excludes", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFileSources(
sources, err := collectFilesystemSources(
fsys,
false,
// exclude 2 files explicitly
@ -75,8 +82,8 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
NewSource(TypeSQL, "00001_foo.sql", 1),
NewSource(TypeSQL, "00003_baz.sql", 3),
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00003_baz.sql", 3),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
@ -89,7 +96,7 @@ func TestCollectFileSources(t *testing.T) {
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
_, err = collectFileSources(fsys, true, nil)
_, err = collectFilesystemSources(fsys, true, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
})
@ -101,7 +108,7 @@ func TestCollectFileSources(t *testing.T) {
"4_qux.sql": sqlMapFile,
"5_foo_test.go": {Data: []byte(`package goose_test`)},
}
sources, err := collectFileSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
@ -116,7 +123,7 @@ func TestCollectFileSources(t *testing.T) {
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
}
sources, err := collectFileSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
@ -135,7 +142,7 @@ func TestCollectFileSources(t *testing.T) {
"001_foo.sql": sqlMapFile,
"01_bar.sql": sqlMapFile,
}
_, err := collectFileSources(mapFS, false, nil)
_, err := collectFilesystemSources(mapFS, false, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
@ -151,7 +158,7 @@ func TestCollectFileSources(t *testing.T) {
t.Helper()
f, err := fs.Sub(mapFS, dirpath)
check.NoError(t, err)
got, err := collectFileSources(f, false, nil)
got, err := collectFilesystemSources(f, false, nil)
check.NoError(t, err)
check.Number(t, len(got.sqlSources), len(sqlSources))
check.Number(t, len(got.goSources), 0)
@ -160,15 +167,15 @@ func TestCollectFileSources(t *testing.T) {
}
}
assertDirpath(".", []Source{
NewSource(TypeSQL, "876_a.sql", 876),
newSource(TypeSQL, "876_a.sql", 876),
})
assertDirpath("dir1", []Source{
NewSource(TypeSQL, "101_a.sql", 101),
NewSource(TypeSQL, "102_b.sql", 102),
NewSource(TypeSQL, "103_c.sql", 103),
newSource(TypeSQL, "101_a.sql", 101),
newSource(TypeSQL, "102_b.sql", 102),
newSource(TypeSQL, "103_c.sql", 103),
})
assertDirpath("dir2", []Source{
NewSource(TypeSQL, "201_a.sql", 201),
newSource(TypeSQL, "201_a.sql", 201),
})
assertDirpath("dir3", nil)
})
@ -187,7 +194,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 1)
check.Equal(t, len(sources.goSources), 2)
@ -205,9 +212,9 @@ func TestMerge(t *testing.T) {
})
check.NoError(t, err)
check.Number(t, len(migrations), 3)
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3))
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
})
t.Run("unregistered_all", func(t *testing.T) {
_, err := merge(sources, nil)
@ -243,7 +250,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
@ -253,11 +260,11 @@ func TestMerge(t *testing.T) {
})
check.NoError(t, err)
check.Number(t, len(migrations), 5)
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "", 3))
assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5))
assertMigration(t, migrations[4], NewSource(TypeGo, "", 6))
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5))
assertMigration(t, migrations[4], newSource(TypeGo, "", 6))
})
})
t.Run("partial_go_files_on_disk", func(t *testing.T) {
@ -267,7 +274,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
@ -279,15 +286,15 @@ func TestMerge(t *testing.T) {
})
check.NoError(t, err)
check.Number(t, len(migrations), 4)
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "", 3))
assertMigration(t, migrations[3], NewSource(TypeGo, "", 6))
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
})
})
}
func TestFindMissingMigrations(t *testing.T) {
func TestCheckMissingMigrations(t *testing.T) {
t.Parallel()
t.Run("db_has_max_version", func(t *testing.T) {
@ -302,24 +309,24 @@ func TestFindMissingMigrations(t *testing.T) {
{Version: 7}, // <-- database max version_id
}
fsMigrations := []*migration{
newMigration(1),
newMigration(2), // missing migration
newMigration(3),
newMigration(4),
newMigration(5),
newMigration(6), // missing migration
newMigration(7), // ----- database max version_id -----
newMigration(8), // new migration
newMigrationVersion(1),
newMigrationVersion(2), // missing migration
newMigrationVersion(3),
newMigrationVersion(4),
newMigrationVersion(5),
newMigrationVersion(6), // missing migration
newMigrationVersion(7), // ----- database max version_id -----
newMigrationVersion(8), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 2)
check.Number(t, got[1].versionID, 6)
// Sanity check.
check.Number(t, len(findMissingMigrations(nil, nil)), 0)
check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0)
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
check.Number(t, len(checkMissingMigrations(nil, nil)), 0)
check.Number(t, len(checkMissingMigrations(dbMigrations, nil)), 0)
check.Number(t, len(checkMissingMigrations(nil, fsMigrations)), 0)
})
t.Run("fs_has_max_version", func(t *testing.T) {
dbMigrations := []*database.ListMigrationsResult{
@ -328,17 +335,17 @@ func TestFindMissingMigrations(t *testing.T) {
{Version: 2},
}
fsMigrations := []*migration{
newMigration(3), // new migration
newMigration(4), // new migration
newMigrationVersion(3), // new migration
newMigrationVersion(4), // new migration
}
got := findMissingMigrations(dbMigrations, fsMigrations)
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 3)
check.Number(t, got[1].versionID, 4)
})
}
func newMigration(version int64) *migration {
func newMigrationVersion(version int64) *migration {
return &migration{
Source: Source{
Version: version,
@ -368,6 +375,14 @@ func newSQLOnlyFS() fstest.MapFS {
}
}
func newSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Path: fullpath,
Version: version,
}
}
var (
sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)}
)

View File

@ -44,7 +44,7 @@ func (m *migration) useTx(direction bool) bool {
func (m *migration) isEmpty(direction bool) bool {
switch m.Source.Type {
case TypeSQL:
return m.SQL == nil || m.SQL.IsEmpty(direction)
return m.SQL == nil || m.SQL.isEmpty(direction)
case TypeGo:
return m.Go == nil || m.Go.isEmpty(direction)
}
@ -102,7 +102,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool)
type goMigration struct {
fullpath string
up, down *GoMigration
up, down *GoMigrationFunc
}
func (g *goMigration) isEmpty(direction bool) bool {
@ -115,7 +115,7 @@ func (g *goMigration) isEmpty(direction bool) bool {
return g.down == nil
}
func newGoMigration(fullpath string, up, down *GoMigration) *goMigration {
func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration {
return &goMigration{
fullpath: fullpath,
up: up,
@ -163,7 +163,7 @@ type sqlMigration struct {
DownStatements []string
}
func (s *sqlMigration) IsEmpty(direction bool) bool {
func (s *sqlMigration) isEmpty(direction bool) bool {
if direction {
return len(s.UpStatements) == 0
}

View File

@ -5,9 +5,11 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3"
)
type Migration struct {
type MigrationCopy struct {
Version int64
Source string // path to .sql script or go file
Registered bool
@ -15,13 +17,13 @@ type Migration struct {
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
}
var registeredGoMigrations = make(map[int64]*Migration)
var registeredGoMigrations = make(map[int64]*MigrationCopy)
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
// the migrations are nil or if a migration with the same version has already been registered.
//
// Not safe for concurrent use.
func SetGlobalGoMigrations(migrations []*Migration) error {
func SetGlobalGoMigrations(migrations []*MigrationCopy) error {
for _, m := range migrations {
if m == nil {
return errors.New("cannot register nil go migration")
@ -35,7 +37,7 @@ func SetGlobalGoMigrations(migrations []*Migration) error {
if m.Source != "" {
// If the source is set, expect it to be a file path with a numeric component that
// matches the version.
version, err := NumericComponent(m.Source)
version, err := goose.NumericComponent(m.Source)
if err != nil {
return err
}
@ -62,5 +64,5 @@ func SetGlobalGoMigrations(migrations []*Migration) error {
//
// Not safe for concurrent use.
func ResetGlobalGoMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
registeredGoMigrations = make(map[int64]*MigrationCopy)
}

View File

@ -12,6 +12,21 @@ import (
"github.com/pressly/goose/v3/database"
)
// Provider is a goose migration provider.
type Provider struct {
// mu protects all accesses to the provider and must be held when calling operations on the
// database.
mu sync.Mutex
db *sql.DB
fsys fs.FS
cfg config
store database.Store
// migrations are ordered by version in ascending order.
migrations []*migration
}
// NewProvider returns a new goose Provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
@ -46,11 +61,13 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
return nil, err
}
}
// Allow users to specify a custom store implementation, but only if they don't specify a
// dialect. If they specify a dialect, we'll use the default store implementation.
if dialect == "" && cfg.store == nil {
return nil, errors.New("dialect must not be empty")
}
if dialect != "" && cfg.store != nil {
return nil, errors.New("cannot set both dialect and store")
return nil, errors.New("cannot set both dialect and custom store")
}
var store database.Store
if dialect != "" {
@ -65,6 +82,16 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
if store.Tablename() == "" {
return nil, errors.New("invalid store implementation: table name must not be empty")
}
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
}
func newProvider(
db *sql.DB,
store database.Store,
fsys fs.FS,
cfg config,
global map[int64]*MigrationCopy,
) (*Provider, error) {
// Collect migrations from the filesystem and merge with registered migrations.
//
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
@ -73,13 +100,10 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
// though, so we should make it optional.
sources, err := collectFileSources(fsys, false, cfg.excludes)
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes)
if err != nil {
return nil, err
}
//
// TODO(mf): move the merging of Go migrations into a separate function.
//
registered := make(map[int64]*goMigration)
// Add user-registered Go migrations.
for version, m := range cfg.registered {
@ -87,7 +111,7 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
}
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
// struct to the new goMigration struct.
for version, m := range registeredGoMigrations {
for version, m := range global {
if _, ok := registered[version]; ok {
return nil, fmt.Errorf("go migration with version %d already registered", version)
}
@ -103,27 +127,27 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
}
// Up
if m.UpFnContext != nil {
g.up = &GoMigration{
g.up = &GoMigrationFunc{
Run: m.UpFnContext,
}
} else if m.UpFnNoTxContext != nil {
g.up = &GoMigration{
g.up = &GoMigrationFunc{
RunNoTx: m.UpFnNoTxContext,
}
}
// Down
if m.DownFnContext != nil {
g.down = &GoMigration{
g.down = &GoMigrationFunc{
Run: m.DownFnContext,
}
} else if m.DownFnNoTxContext != nil {
g.down = &GoMigration{
g.down = &GoMigrationFunc{
RunNoTx: m.DownFnNoTxContext,
}
}
registered[version] = g
}
migrations, err := merge(sources, registered)
migrations, err := merge(filesystemSources, registered)
if err != nil {
return nil, err
}
@ -139,21 +163,6 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi
}, nil
}
// Provider is a goose migration provider.
type Provider struct {
// mu protects all accesses to the provider and must be held when calling operations on the
// database.
mu sync.Mutex
db *sql.DB
fsys fs.FS
cfg config
store database.Store
// migrations are ordered by version in ascending order.
migrations []*migration
}
// Status returns the status of all migrations, merging the list of migrations from the database and
// filesystem. The returned items are ordered by version, in ascending order.
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {

View File

@ -96,8 +96,8 @@ func WithExcludes(excludes []string) ProviderOption {
})
}
// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration].
type GoMigration struct {
// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration].
type GoMigrationFunc struct {
// One of the following must be set:
Run func(context.Context, *sql.Tx) error
// -- OR --
@ -109,7 +109,7 @@ type GoMigration struct {
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
// set.
func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption {
return configFunc(func(c *config) error {
if version < 1 {
return errors.New("version must be greater than zero")
@ -143,25 +143,27 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption {
})
}
// WithAllowMissing allows the provider to apply missing (out-of-order) migrations.
// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default,
// goose will raise an error if it encounters a missing migration.
//
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
// by new migrations.
func WithAllowMissing(b bool) ProviderOption {
func WithAllowedMissing(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.allowMissing = b
return nil
})
}
// WithNoVersioning disables versioning. Disabling versioning allows applying migrations without
// tracking the versions in the database schema table. Useful for tests, seeding a database or
// running ad-hoc queries.
func WithNoVersioning(b bool) ProviderOption {
// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations
// without tracking the versions in the database schema table. Useful for tests, seeding a database
// or running ad-hoc queries. By default, goose will track all versions in the database schema
// table.
func WithDisabledVersioning(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.noVersioning = b
c.disableVersioning = b
return nil
})
}
@ -181,8 +183,8 @@ type config struct {
sessionLocker lock.SessionLocker
// Feature
noVersioning bool
allowMissing bool
disableVersioning bool
allowMissing bool
}
type configFunc func(*config) error

View File

@ -35,12 +35,12 @@ func TestProvider(t *testing.T) {
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
check.Equal(t, sources[0], provider.NewSource(provider.TypeSQL, "001_foo.sql", 1))
check.Equal(t, sources[1], provider.NewSource(provider.TypeSQL, "002_bar.sql", 2))
check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1))
check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2))
t.Run("duplicate_go", func(t *testing.T) {
// Not parallel because it modifies global state.
register := []*provider.Migration{
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: nil,
@ -62,13 +62,13 @@ func TestProvider(t *testing.T) {
db := newDB(t)
// explicit
_, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}),
provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}),
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 must have an up function")
})
t.Run("duplicate_up", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
@ -80,7 +80,7 @@ func TestProvider(t *testing.T) {
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
})
t.Run("duplicate_down", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
@ -92,7 +92,7 @@ func TestProvider(t *testing.T) {
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
})
t.Run("not_registered", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go",
},
@ -102,7 +102,7 @@ func TestProvider(t *testing.T) {
check.Contains(t, err.Error(), "migration must be registered")
})
t.Run("zero_not_allowed", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.Migration{
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 0,
},

View File

@ -34,13 +34,13 @@ func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*Mi
return nil, nil
}
var apply []*migration
if p.cfg.noVersioning {
if p.cfg.disableVersioning {
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the allow missing (out-of-order) feature. For users that don't use this feature, we
// could just query the database for the current max version and then apply migrations greater
// than that version.
// optimize(mf): Listing all migrations from the database isn't great. This is only required
// to support the allow missing (out-of-order) feature. For users that don't use this
// feature, we could just query the database for the current max version and then apply
// migrations greater than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
@ -76,13 +76,13 @@ func (p *Provider) resolveUpMigrations(
dbMaxVersion = m.Version
}
}
missingMigrations := findMissingMigrations(dbVersions, p.migrations)
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3.
// Not sure if this is a common use case, but it's possible.
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
// sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
@ -127,7 +127,7 @@ func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ [
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
if p.cfg.disableVersioning {
downMigrations := p.migrations
if downByOne {
last := p.migrations[len(p.migrations)-1]
@ -245,7 +245,7 @@ func (p *Provider) runIndividually(
if err := m.run(ctx, tx, direction); err != nil {
return err
}
if p.cfg.noVersioning {
if p.cfg.disableVersioning {
return nil
}
if direction {
@ -268,7 +268,7 @@ func (p *Provider) runIndividually(
return err
}
}
if p.cfg.noVersioning {
if p.cfg.disableVersioning {
return nil
}
if direction {
@ -329,7 +329,7 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
// need the version table because there is no versioning.
if !p.cfg.noVersioning {
if !p.cfg.disableVersioning {
if err := p.ensureVersionTable(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
@ -370,7 +370,7 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
if p.cfg.noVersioning {
if p.cfg.disableVersioning {
return nil
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
@ -382,9 +382,9 @@ type missingMigration struct {
filename string
}
// findMissingMigrations returns a list of migrations that are missing from the database. A missing
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
// migration is one that has a version less than the max version in the database.
func findMissingMigrations(
func checkMissingMigrations(
dbMigrations []*database.ListMigrationsResult,
fsMigrations []*migration,
) []missingMigration {

View File

@ -78,24 +78,24 @@ func TestProviderRun(t *testing.T) {
res, err := p.Up(ctx)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
// Test Down
res, err = p.DownTo(ctx, 0)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
})
t.Run("up_and_down_by_one", func(t *testing.T) {
ctx := context.Background()
@ -149,8 +149,8 @@ func TestProviderRun(t *testing.T) {
results, err := p.UpTo(ctx, upToVersion)
check.NoError(t, err)
check.Number(t, len(results), upToVersion)
assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
// Fetch the goose version from DB
currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err)
@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) {
status, err := p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), numCount)
assertStatus(t, status[0], provider.StatePending, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), true)
assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), true)
assertStatus(t, status[3], provider.StatePending, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), true)
assertStatus(t, status[4], provider.StatePending, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), true)
assertStatus(t, status[5], provider.StatePending, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), true)
assertStatus(t, status[6], provider.StatePending, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true)
assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true)
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true)
assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true)
assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true)
assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true)
assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true)
// Apply all migrations
_, err = p.Up(ctx)
check.NoError(t, err)
status, err = p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), numCount)
assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), false)
assertStatus(t, status[2], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), false)
assertStatus(t, status[3], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), false)
assertStatus(t, status[4], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), false)
assertStatus(t, status[5], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), false)
assertStatus(t, status[6], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false)
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false)
assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false)
assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false)
assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false)
assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false)
assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false)
})
t.Run("tx_partial_errors", func(t *testing.T) {
countOwners := func(db *sql.DB) (int, error) {
@ -333,7 +333,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
// Check Results field
check.Number(t, len(expected.Applied), 1)
assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
// Check Failed field
check.Bool(t, expected.Failed != nil, true)
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
status, err := p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), 3)
assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_partial_error.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_insert_data.sql", 3), true)
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true)
})
}
@ -488,7 +488,7 @@ func TestNoVersioning(t *testing.T) {
)
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
provider.WithNoVersioning(false), // This is the default.
provider.WithDisabledVersioning(false), // This is the default.
)
check.Number(t, len(p.ListSources()), 3)
check.NoError(t, err)
@ -501,7 +501,7 @@ func TestNoVersioning(t *testing.T) {
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
provider.WithNoVersioning(true), // Provider with no versioning.
provider.WithDisabledVersioning(true), // Provider with no versioning.
)
check.NoError(t, err)
check.Number(t, len(p.ListSources()), 2)
@ -553,7 +553,7 @@ func TestAllowMissing(t *testing.T) {
t.Run("missing_now_allowed", func(t *testing.T) {
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
provider.WithAllowMissing(false),
provider.WithAllowedMissing(false),
)
check.NoError(t, err)
@ -608,7 +608,7 @@ func TestAllowMissing(t *testing.T) {
t.Run("missing_allowed", func(t *testing.T) {
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
provider.WithAllowMissing(true),
provider.WithAllowedMissing(true),
)
check.NoError(t, err)
@ -703,7 +703,7 @@ func TestGoOnly(t *testing.T) {
t.Run("with_tx", func(t *testing.T) {
ctx := context.Background()
register := []*provider.Migration{
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
@ -718,8 +718,8 @@ func TestGoOnly(t *testing.T) {
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(
2,
&provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigration{Run: newTxFn("DELETE FROM users")},
&provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")},
),
)
check.NoError(t, err)
@ -730,29 +730,29 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
t.Run("with_db", func(t *testing.T) {
ctx := context.Background()
register := []*provider.Migration{
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
@ -767,8 +767,8 @@ func TestGoOnly(t *testing.T) {
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(
2,
&provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigration{RunNoTx: newDBFn("DELETE FROM users")},
&provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")},
),
)
check.NoError(t, err)
@ -779,23 +779,23 @@ func TestGoOnly(t *testing.T) {
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
@ -1148,6 +1148,14 @@ func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType,
}
}
func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source {
return provider.Source{
Type: t,
Path: fullpath,
Version: version,
}
}
func newMapFile(data string) *fstest.MapFile {
return &fstest.MapFile{
Data: []byte(data),