// Mirror of https://github.com/pressly/goose.git (325 lines, 11 KiB, Go).
package goose
|
|
|
|
import (
|
|
"io/fs"
|
|
"testing"
|
|
"testing/fstest"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestCollectFileSources(t *testing.T) {
|
|
t.Parallel()
|
|
t.Run("nil_fsys", func(t *testing.T) {
|
|
sources, err := collectFilesystemSources(nil, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sources)
|
|
require.Empty(t, sources.goSources)
|
|
require.Empty(t, sources.sqlSources)
|
|
})
|
|
t.Run("noop_fsys", func(t *testing.T) {
|
|
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sources)
|
|
require.Empty(t, sources.goSources)
|
|
require.Empty(t, sources.sqlSources)
|
|
})
|
|
t.Run("empty_fsys", func(t *testing.T) {
|
|
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Empty(t, sources.goSources)
|
|
require.Empty(t, sources.sqlSources)
|
|
require.NotNil(t, sources)
|
|
})
|
|
t.Run("incorrect_fsys", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"00000_foo.sql": sqlMapFile,
|
|
}
|
|
// strict disable - should not error
|
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Empty(t, sources.goSources)
|
|
require.Empty(t, sources.sqlSources)
|
|
// strict enabled - should error
|
|
_, err = collectFilesystemSources(mapFS, true, nil, nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "migration version must be greater than zero")
|
|
})
|
|
t.Run("collect", func(t *testing.T) {
|
|
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
|
require.NoError(t, err)
|
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, sources.sqlSources, 4)
|
|
require.Empty(t, sources.goSources)
|
|
expected := fileSources{
|
|
sqlSources: []Source{
|
|
newSource(TypeSQL, "00001_foo.sql", 1),
|
|
newSource(TypeSQL, "00002_bar.sql", 2),
|
|
newSource(TypeSQL, "00003_baz.sql", 3),
|
|
newSource(TypeSQL, "00110_qux.sql", 110),
|
|
},
|
|
}
|
|
for i := 0; i < len(sources.sqlSources); i++ {
|
|
require.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
|
}
|
|
})
|
|
t.Run("excludes", func(t *testing.T) {
|
|
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
|
require.NoError(t, err)
|
|
sources, err := collectFilesystemSources(
|
|
fsys,
|
|
false,
|
|
// exclude 2 files explicitly
|
|
map[string]bool{
|
|
"00002_bar.sql": true,
|
|
"00110_qux.sql": true,
|
|
},
|
|
nil,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, sources.sqlSources, 2)
|
|
require.Empty(t, sources.goSources)
|
|
expected := fileSources{
|
|
sqlSources: []Source{
|
|
newSource(TypeSQL, "00001_foo.sql", 1),
|
|
newSource(TypeSQL, "00003_baz.sql", 3),
|
|
},
|
|
}
|
|
for i := 0; i < len(sources.sqlSources); i++ {
|
|
require.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
|
}
|
|
})
|
|
t.Run("strict", func(t *testing.T) {
|
|
mapFS := newSQLOnlyFS()
|
|
// Add a file with no version number
|
|
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
|
fsys, err := fs.Sub(mapFS, "migrations")
|
|
require.NoError(t, err)
|
|
_, err = collectFilesystemSources(fsys, true, nil, nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
|
})
|
|
t.Run("skip_go_test_files", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"1_foo.sql": sqlMapFile,
|
|
"2_bar.sql": sqlMapFile,
|
|
"3_baz.sql": sqlMapFile,
|
|
"4_qux.sql": sqlMapFile,
|
|
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
|
}
|
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, sources.sqlSources, 4)
|
|
require.Empty(t, sources.goSources)
|
|
})
|
|
t.Run("skip_random_files", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"1_foo.sql": sqlMapFile,
|
|
"4_something.go": {Data: []byte(`package goose`)},
|
|
"5_qux.sql": sqlMapFile,
|
|
"README.md": {Data: []byte(`# README`)},
|
|
"LICENSE": {Data: []byte(`MIT`)},
|
|
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
|
|
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
|
}
|
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, sources.sqlSources, 2)
|
|
require.Len(t, sources.goSources, 1)
|
|
// 1
|
|
require.Equal(t, "1_foo.sql", sources.sqlSources[0].Path)
|
|
require.EqualValues(t, 1, sources.sqlSources[0].Version)
|
|
// 2
|
|
require.Equal(t, "5_qux.sql", sources.sqlSources[1].Path)
|
|
require.EqualValues(t, 5, sources.sqlSources[1].Version)
|
|
// 3
|
|
require.Equal(t, "4_something.go", sources.goSources[0].Path)
|
|
require.EqualValues(t, 4, sources.goSources[0].Version)
|
|
})
|
|
t.Run("duplicate_versions", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"001_foo.sql": sqlMapFile,
|
|
"01_bar.sql": sqlMapFile,
|
|
}
|
|
_, err := collectFilesystemSources(mapFS, false, nil, nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "found duplicate migration version 1")
|
|
})
|
|
t.Run("dirpath", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"dir1/101_a.sql": sqlMapFile,
|
|
"dir1/102_b.sql": sqlMapFile,
|
|
"dir1/103_c.sql": sqlMapFile,
|
|
"dir2/201_a.sql": sqlMapFile,
|
|
"876_a.sql": sqlMapFile,
|
|
}
|
|
assertDirpath := func(dirpath string, sqlSources []Source) {
|
|
t.Helper()
|
|
f, err := fs.Sub(mapFS, dirpath)
|
|
require.NoError(t, err)
|
|
got, err := collectFilesystemSources(f, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, sqlSources, len(got.sqlSources))
|
|
require.Empty(t, got.goSources)
|
|
for i := 0; i < len(got.sqlSources); i++ {
|
|
require.Equal(t, got.sqlSources[i], sqlSources[i])
|
|
}
|
|
}
|
|
assertDirpath(".", []Source{
|
|
newSource(TypeSQL, "876_a.sql", 876),
|
|
})
|
|
assertDirpath("dir1", []Source{
|
|
newSource(TypeSQL, "101_a.sql", 101),
|
|
newSource(TypeSQL, "102_b.sql", 102),
|
|
newSource(TypeSQL, "103_c.sql", 103),
|
|
})
|
|
assertDirpath("dir2", []Source{
|
|
newSource(TypeSQL, "201_a.sql", 201),
|
|
})
|
|
assertDirpath("dir3", nil)
|
|
})
|
|
}
|
|
|
|
func TestMerge(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("with_go_files_on_disk", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
// SQL
|
|
"migrations/00001_foo.sql": sqlMapFile,
|
|
// Go
|
|
"migrations/00002_bar.go": {Data: []byte(`package migrations`)},
|
|
"migrations/00003_baz.go": {Data: []byte(`package migrations`)},
|
|
}
|
|
fsys, err := fs.Sub(mapFS, "migrations")
|
|
require.NoError(t, err)
|
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
|
require.NoError(t, err)
|
|
require.Len(t, sources.sqlSources, 1)
|
|
require.Len(t, sources.goSources, 2)
|
|
t.Run("valid", func(t *testing.T) {
|
|
registered := map[int64]*Migration{
|
|
2: NewGoMigration(2, nil, nil),
|
|
3: NewGoMigration(3, nil, nil),
|
|
}
|
|
migrations, err := merge(sources, registered)
|
|
require.NoError(t, err)
|
|
require.Len(t, migrations, 3)
|
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
|
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
|
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
|
})
|
|
t.Run("unregistered_all", func(t *testing.T) {
|
|
_, err := merge(sources, nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "error: detected 2 unregistered Go files:")
|
|
require.Contains(t, err.Error(), "00002_bar.go")
|
|
require.Contains(t, err.Error(), "00003_baz.go")
|
|
})
|
|
t.Run("unregistered_some", func(t *testing.T) {
|
|
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
|
require.Contains(t, err.Error(), "00003_baz.go")
|
|
})
|
|
t.Run("duplicate_sql", func(t *testing.T) {
|
|
_, err := merge(sources, map[int64]*Migration{
|
|
1: NewGoMigration(1, nil, nil), // duplicate. SQL already exists.
|
|
2: NewGoMigration(2, nil, nil),
|
|
3: NewGoMigration(3, nil, nil),
|
|
})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "found duplicate migration version 1")
|
|
})
|
|
})
|
|
t.Run("no_go_files_on_disk", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
// SQL
|
|
"migrations/00001_foo.sql": sqlMapFile,
|
|
"migrations/00002_bar.sql": sqlMapFile,
|
|
"migrations/00005_baz.sql": sqlMapFile,
|
|
}
|
|
fsys, err := fs.Sub(mapFS, "migrations")
|
|
require.NoError(t, err)
|
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
|
require.NoError(t, err)
|
|
t.Run("unregistered_all", func(t *testing.T) {
|
|
migrations, err := merge(sources, map[int64]*Migration{
|
|
3: NewGoMigration(3, nil, nil),
|
|
// 4 is missing
|
|
6: NewGoMigration(6, nil, nil),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, migrations, 5)
|
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
|
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
|
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
|
assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5))
|
|
assertMigration(t, migrations[4], newSource(TypeGo, "", 6))
|
|
})
|
|
})
|
|
t.Run("partial_go_files_on_disk", func(t *testing.T) {
|
|
mapFS := fstest.MapFS{
|
|
"migrations/00001_foo.sql": sqlMapFile,
|
|
"migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)},
|
|
}
|
|
fsys, err := fs.Sub(mapFS, "migrations")
|
|
require.NoError(t, err)
|
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
|
require.NoError(t, err)
|
|
t.Run("unregistered_all", func(t *testing.T) {
|
|
migrations, err := merge(sources, map[int64]*Migration{
|
|
// This is the only Go file on disk.
|
|
2: NewGoMigration(2, nil, nil),
|
|
// These are not on disk. Explicitly registered.
|
|
3: NewGoMigration(3, nil, nil),
|
|
6: NewGoMigration(6, nil, nil),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, migrations, 4)
|
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
|
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
|
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
|
|
})
|
|
})
|
|
}
|
|
|
|
func assertMigration(t *testing.T, got *Migration, want Source) {
|
|
t.Helper()
|
|
require.Equal(t, want.Type, got.Type)
|
|
require.Equal(t, want.Version, got.Version)
|
|
require.Equal(t, want.Path, got.Source)
|
|
switch got.Type {
|
|
case TypeGo:
|
|
require.NotNil(t, got.goUp)
|
|
require.NotNil(t, got.goDown)
|
|
case TypeSQL:
|
|
require.False(t, got.sql.Parsed)
|
|
default:
|
|
t.Fatalf("unknown migration type: %s", got.Type)
|
|
}
|
|
}
|
|
|
|
func newSQLOnlyFS() fstest.MapFS {
|
|
return fstest.MapFS{
|
|
"migrations/00001_foo.sql": sqlMapFile,
|
|
"migrations/00002_bar.sql": sqlMapFile,
|
|
"migrations/00003_baz.sql": sqlMapFile,
|
|
"migrations/00110_qux.sql": sqlMapFile,
|
|
}
|
|
}
|
|
|
|
func newSource(t MigrationType, fullpath string, version int64) Source {
|
|
return Source{
|
|
Type: t,
|
|
Path: fullpath,
|
|
Version: version,
|
|
}
|
|
}
|
|
|
|
// sqlMapFile is the fixture used for every SQL migration file in these tests.
// Its content is a single "-- +goose Up" annotation with no statements.
// The same pointer is shared by many test filesystems, so tests must not
// mutate it.
var sqlMapFile = &fstest.MapFile{
	Data: []byte(`-- +goose Up`),
}
|