mirror of https://github.com/pressly/goose.git
feat(experimental): add collect migrations logic and new Provider options (#615)
parent
fe8fe975d8
commit
68853f91ea
3
Makefile
3
Makefile
|
@ -33,6 +33,9 @@ tools:
|
|||
test-packages:
|
||||
go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
|
||||
|
||||
test-packages-short:
|
||||
go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
|
||||
|
||||
test-e2e: test-e2e-postgres test-e2e-mysql test-e2e-clickhouse test-e2e-vertica
|
||||
|
||||
test-e2e-postgres:
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
|
||||
func TestSequential(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.Skip("skip long running test")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
defer os.Remove("./bin/create-goose") // clean up
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
|
||||
func TestFix(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.Skip("skip long running test")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
defer os.Remove("./bin/fix-goose") // clean up
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
)
|
||||
|
||||
func TestParsingGoMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
|
@ -38,6 +39,7 @@ func TestParsingGoMigrations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParsingGoMigrationsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := parseGoFile(strings.NewReader(emptyInit))
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "no registered goose functions")
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/migrate"
|
||||
)
|
||||
|
||||
// fileSources represents a collection of migration files on the filesystem.
|
||||
type fileSources struct {
|
||||
sqlSources []Source
|
||||
goSources []Source
|
||||
}
|
||||
|
||||
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
|
||||
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
|
||||
// in which case an empty fileSources is returned.
|
||||
//
|
||||
// If strict is true, then any error parsing the numeric component of the filename will result in an
|
||||
// error. The file is skipped otherwise.
|
||||
//
|
||||
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
|
||||
// migration sources from the filesystem.
|
||||
func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
|
||||
if fsys == nil {
|
||||
return new(fileSources), nil
|
||||
}
|
||||
sources := new(fileSources)
|
||||
versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath)
|
||||
for _, pattern := range []string{
|
||||
"*.sql",
|
||||
"*.go",
|
||||
} {
|
||||
files, err := fs.Glob(fsys, pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
|
||||
}
|
||||
for _, fullpath := range files {
|
||||
base := filepath.Base(fullpath)
|
||||
// Skip explicit excludes or Go test files.
|
||||
if excludes[base] || strings.HasSuffix(base, "_test.go") {
|
||||
continue
|
||||
}
|
||||
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
|
||||
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
|
||||
// filenames, but still have versioned migrations within the same directory. For
|
||||
// example, a user could have a helpers.go file which contains unexported helper
|
||||
// functions for migrations.
|
||||
version, err := goose.NumericComponent(base)
|
||||
if err != nil {
|
||||
if strict {
|
||||
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Ensure there are no duplicate versions.
|
||||
if existing, ok := versionToBaseLookup[version]; ok {
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
version,
|
||||
existing,
|
||||
base,
|
||||
)
|
||||
}
|
||||
switch filepath.Ext(base) {
|
||||
case ".sql":
|
||||
sources.sqlSources = append(sources.sqlSources, Source{
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
})
|
||||
case ".go":
|
||||
sources.goSources = append(sources.goSources, Source{
|
||||
Fullpath: fullpath,
|
||||
Version: version,
|
||||
})
|
||||
default:
|
||||
// Should never happen since we already filtered out all other file types.
|
||||
return nil, fmt.Errorf("unknown migration type: %s", base)
|
||||
}
|
||||
// Add the version to the lookup map.
|
||||
versionToBaseLookup[version] = base
|
||||
}
|
||||
}
|
||||
return sources, nil
|
||||
}
|
||||
|
||||
func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) {
|
||||
var migrations []*migrate.Migration
|
||||
migrationLookup := make(map[int64]*migrate.Migration)
|
||||
// Add all SQL migrations to the list of migrations.
|
||||
for _, s := range sources.sqlSources {
|
||||
m := &migrate.Migration{
|
||||
Type: migrate.TypeSQL,
|
||||
Fullpath: s.Fullpath,
|
||||
Version: s.Version,
|
||||
SQLParsed: false,
|
||||
}
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[s.Version] = m
|
||||
}
|
||||
// If there are no Go files in the filesystem and no registered Go migrations, return early.
|
||||
if len(sources.goSources) == 0 && len(registerd) == 0 {
|
||||
return migrations, nil
|
||||
}
|
||||
// Return an error if the given sources contain a versioned Go migration that has not been
|
||||
// registered. This is a sanity check to ensure users didn't accidentally create a valid looking
|
||||
// Go migration file on disk and forget to register it.
|
||||
//
|
||||
// This is almost always a user error.
|
||||
var unregistered []string
|
||||
for _, s := range sources.goSources {
|
||||
if _, ok := registerd[s.Version]; !ok {
|
||||
unregistered = append(unregistered, s.Fullpath)
|
||||
}
|
||||
}
|
||||
if len(unregistered) > 0 {
|
||||
return nil, unregisteredError(unregistered)
|
||||
}
|
||||
// Add all registered Go migrations to the list of migrations, checking for duplicate versions.
|
||||
//
|
||||
// Important, users can register Go migrations manually via goose.Add_ functions. These
|
||||
// migrations may not have a corresponding file on disk. Which is fine! We include them
|
||||
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
|
||||
// the SQL migration files.
|
||||
for _, r := range registerd {
|
||||
// Ensure there are no duplicate versions.
|
||||
if existing, ok := migrationLookup[r.Version]; ok {
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
r.Version,
|
||||
existing,
|
||||
filepath.Base(r.Source),
|
||||
)
|
||||
}
|
||||
m := &migrate.Migration{
|
||||
Fullpath: r.Source, // May be empty if the migration was registered manually.
|
||||
Version: r.Version,
|
||||
Type: migrate.TypeGo,
|
||||
Go: &migrate.Go{
|
||||
UseTx: r.UseTx,
|
||||
UpFn: r.UpFnContext,
|
||||
UpFnNoTx: r.UpFnNoTxContext,
|
||||
DownFn: r.DownFnContext,
|
||||
DownFnNoTx: r.DownFnNoTxContext,
|
||||
},
|
||||
}
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[r.Version] = m
|
||||
}
|
||||
// Sort migrations by version in ascending order.
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
})
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func unregisteredError(unregistered []string) error {
|
||||
f := "file"
|
||||
if len(unregistered) > 1 {
|
||||
f += "s"
|
||||
}
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f))
|
||||
for _, name := range unregistered {
|
||||
b.WriteString("\t" + name + "\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations")
|
||||
|
||||
return errors.New(b.String())
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
)
|
||||
|
||||
func TestCollectFileSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
sources, err := collectFileSources(nil, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Bool(t, sources != nil, true)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
})
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
sources, err := collectFileSources(fstest.MapFS{}, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
check.Bool(t, sources != nil, true)
|
||||
})
|
||||
t.Run("incorrect_fsys", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"00000_foo.sql": sqlMapFile,
|
||||
}
|
||||
// strict disable - should not error
|
||||
sources, err := collectFileSources(mapFS, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
// strict enabled - should error
|
||||
_, err = collectFileSources(mapFS, true, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration version must be greater than zero")
|
||||
})
|
||||
t.Run("collect", func(t *testing.T) {
|
||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFileSources(fsys, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 4)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
expected := fileSources{
|
||||
sqlSources: []Source{
|
||||
{Fullpath: "00001_foo.sql", Version: 1},
|
||||
{Fullpath: "00002_bar.sql", Version: 2},
|
||||
{Fullpath: "00003_baz.sql", Version: 3},
|
||||
{Fullpath: "00110_qux.sql", Version: 110},
|
||||
},
|
||||
}
|
||||
for i := 0; i < len(sources.sqlSources); i++ {
|
||||
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
||||
}
|
||||
})
|
||||
t.Run("excludes", func(t *testing.T) {
|
||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFileSources(
|
||||
fsys,
|
||||
false,
|
||||
// exclude 2 files explicitly
|
||||
map[string]bool{
|
||||
"00002_bar.sql": true,
|
||||
"00110_qux.sql": true,
|
||||
},
|
||||
)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 2)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
expected := fileSources{
|
||||
sqlSources: []Source{
|
||||
{Fullpath: "00001_foo.sql", Version: 1},
|
||||
{Fullpath: "00003_baz.sql", Version: 3},
|
||||
},
|
||||
}
|
||||
for i := 0; i < len(sources.sqlSources); i++ {
|
||||
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
||||
}
|
||||
})
|
||||
t.Run("strict", func(t *testing.T) {
|
||||
mapFS := newSQLOnlyFS()
|
||||
// Add a file with no version number
|
||||
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
_, err = collectFileSources(fsys, true, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
||||
})
|
||||
t.Run("skip_go_test_files", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"1_foo.sql": sqlMapFile,
|
||||
"2_bar.sql": sqlMapFile,
|
||||
"3_baz.sql": sqlMapFile,
|
||||
"4_qux.sql": sqlMapFile,
|
||||
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
||||
}
|
||||
sources, err := collectFileSources(mapFS, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 4)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
})
|
||||
t.Run("skip_random_files", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"1_foo.sql": sqlMapFile,
|
||||
"4_something.go": {Data: []byte(`package goose`)},
|
||||
"5_qux.sql": sqlMapFile,
|
||||
"README.md": {Data: []byte(`# README`)},
|
||||
"LICENSE": {Data: []byte(`MIT`)},
|
||||
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
|
||||
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
||||
}
|
||||
sources, err := collectFileSources(mapFS, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 2)
|
||||
check.Number(t, len(sources.goSources), 1)
|
||||
// 1
|
||||
check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql")
|
||||
check.Equal(t, sources.sqlSources[0].Version, int64(1))
|
||||
// 2
|
||||
check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql")
|
||||
check.Equal(t, sources.sqlSources[1].Version, int64(5))
|
||||
// 3
|
||||
check.Equal(t, sources.goSources[0].Fullpath, "4_something.go")
|
||||
check.Equal(t, sources.goSources[0].Version, int64(4))
|
||||
})
|
||||
t.Run("duplicate_versions", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"001_foo.sql": sqlMapFile,
|
||||
"01_bar.sql": sqlMapFile,
|
||||
}
|
||||
_, err := collectFileSources(mapFS, false, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||
})
|
||||
t.Run("dirpath", func(t *testing.T) {
|
||||
mapFS := fstest.MapFS{
|
||||
"dir1/101_a.sql": sqlMapFile,
|
||||
"dir1/102_b.sql": sqlMapFile,
|
||||
"dir1/103_c.sql": sqlMapFile,
|
||||
"dir2/201_a.sql": sqlMapFile,
|
||||
"876_a.sql": sqlMapFile,
|
||||
}
|
||||
assertDirpath := func(dirpath string, sqlSources []Source) {
|
||||
t.Helper()
|
||||
f, err := fs.Sub(mapFS, dirpath)
|
||||
check.NoError(t, err)
|
||||
got, err := collectFileSources(f, false, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(got.sqlSources), len(sqlSources))
|
||||
check.Number(t, len(got.goSources), 0)
|
||||
for i := 0; i < len(got.sqlSources); i++ {
|
||||
check.Equal(t, got.sqlSources[i], sqlSources[i])
|
||||
}
|
||||
}
|
||||
assertDirpath(".", []Source{
|
||||
{Fullpath: "876_a.sql", Version: 876},
|
||||
})
|
||||
assertDirpath("dir1", []Source{
|
||||
{Fullpath: "101_a.sql", Version: 101},
|
||||
{Fullpath: "102_b.sql", Version: 102},
|
||||
{Fullpath: "103_c.sql", Version: 103},
|
||||
})
|
||||
assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}})
|
||||
assertDirpath("dir3", nil)
|
||||
})
|
||||
}
|
||||
|
||||
func newSQLOnlyFS() fstest.MapFS {
|
||||
return fstest.MapFS{
|
||||
"migrations/00001_foo.sql": sqlMapFile,
|
||||
"migrations/00002_bar.sql": sqlMapFile,
|
||||
"migrations/00003_baz.sql": sqlMapFile,
|
||||
"migrations/00110_qux.sql": sqlMapFile,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)}
|
||||
)
|
|
@ -5,22 +5,29 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/migrate"
|
||||
"github.com/pressly/goose/v3/internal/sqladapter"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
|
||||
ErrNoMigrations = errors.New("no migrations found")
|
||||
)
|
||||
|
||||
// NewProvider returns a new goose Provider.
|
||||
//
|
||||
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
||||
// example, if the database dialect is "postgres", the database/sql driver could be
|
||||
// github.com/lib/pq or github.com/jackc/pgx.
|
||||
//
|
||||
// fsys is the filesystem used to read the migration files. Most users will want to use
|
||||
// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
|
||||
// possible to use a different filesystem, such as embed.FS.
|
||||
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
|
||||
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
|
||||
// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub.
|
||||
//
|
||||
// Functional options are used to configure the Provider. See [ProviderOption] for more information.
|
||||
// See [ProviderOption] for more information on configuring the provider.
|
||||
//
|
||||
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
|
||||
//
|
||||
|
@ -33,7 +40,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
|
|||
return nil, errors.New("dialect must not be empty")
|
||||
}
|
||||
if fsys == nil {
|
||||
return nil, errors.New("fsys must not be nil")
|
||||
fsys = noopFS{}
|
||||
}
|
||||
var cfg config
|
||||
for _, opt := range opts {
|
||||
|
@ -41,7 +48,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
// Set defaults
|
||||
// Set defaults after applying user-supplied options so option funcs can check for empty values.
|
||||
if cfg.tableName == "" {
|
||||
cfg.tableName = defaultTablename
|
||||
}
|
||||
|
@ -49,41 +56,76 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(mf): implement the rest of this function - collect sources - merge sources into
|
||||
// migrations
|
||||
// Collect migrations from the filesystem and merge with registered migrations.
|
||||
//
|
||||
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
|
||||
// lazily.
|
||||
//
|
||||
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
|
||||
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
|
||||
// though, so we should make it optional.
|
||||
sources, err := collectFileSources(fsys, false, cfg.excludes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrations, err := merge(sources, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(migrations) == 0 {
|
||||
return nil, ErrNoMigrations
|
||||
}
|
||||
return &Provider{
|
||||
db: db,
|
||||
fsys: fsys,
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
db: db,
|
||||
fsys: fsys,
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
migrations: migrations,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Provider is a goose migration provider.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
type Provider struct {
|
||||
db *sql.DB
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
store sqladapter.Store
|
||||
type noopFS struct{}
|
||||
|
||||
var _ fs.FS = noopFS{}
|
||||
|
||||
func (f noopFS) Open(name string) (fs.File, error) {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
// Provider is a goose migration provider.
|
||||
type Provider struct {
|
||||
db *sql.DB
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
store sqladapter.Store
|
||||
migrations []*migrate.Migration
|
||||
}
|
||||
|
||||
// State represents the state of a migration.
|
||||
type State string
|
||||
|
||||
const (
|
||||
// StateUntracked represents a migration that is in the database, but not on the filesystem.
|
||||
StateUntracked State = "untracked"
|
||||
// StatePending represents a migration that is on the filesystem, but not in the database.
|
||||
StatePending State = "pending"
|
||||
// StateApplied represents a migration that is in BOTH the database and on the filesystem.
|
||||
StateApplied State = "applied"
|
||||
)
|
||||
|
||||
// MigrationStatus represents the status of a single migration.
|
||||
type MigrationStatus struct {
|
||||
// State represents the state of the migration. One of "untracked", "pending", "applied".
|
||||
// - untracked: in the database, but not on the filesystem.
|
||||
// - pending: on the filesystem, but not in the database.
|
||||
// - applied: in both the database and on the filesystem.
|
||||
State string
|
||||
// AppliedAt is the time the migration was applied. Only set if state is applied or untracked.
|
||||
// State is the state of the migration.
|
||||
State State
|
||||
// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
|
||||
// [StateUntracked].
|
||||
AppliedAt time.Time
|
||||
// Source is the migration source. Only set if the state is pending or applied.
|
||||
Source Source
|
||||
// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
|
||||
Source *Source
|
||||
}
|
||||
|
||||
// Status returns the status of all migrations, merging the list of migrations from the database and
|
||||
// filesystem. The returned items are ordered by version, in ascending order.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -91,7 +133,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
|||
// GetDBVersion returns the max version from the database, regardless of the applied order. For
|
||||
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
|
||||
// applied, it returns 0.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
@ -111,7 +152,6 @@ const (
|
|||
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
|
||||
// the migration has a corresponding file on disk. It will be empty if the migration was registered
|
||||
// manually.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
type Source struct {
|
||||
// Type is the type of migration.
|
||||
Type SourceType
|
||||
|
@ -123,22 +163,34 @@ type Source struct {
|
|||
Version int64
|
||||
}
|
||||
|
||||
// ListSources returns a list of all available migration sources the provider is aware of.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
|
||||
// ascending order by version.
|
||||
func (p *Provider) ListSources() []*Source {
|
||||
return nil
|
||||
sources := make([]*Source, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
s := &Source{
|
||||
Fullpath: m.Fullpath,
|
||||
Version: m.Version,
|
||||
}
|
||||
switch m.Type {
|
||||
case migrate.TypeSQL:
|
||||
s.Type = SourceTypeSQL
|
||||
case migrate.TypeGo:
|
||||
s.Type = SourceTypeGo
|
||||
}
|
||||
sources = append(sources, s)
|
||||
}
|
||||
return sources
|
||||
}
|
||||
|
||||
// Ping attempts to ping the database to verify a connection is available.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) Ping(ctx context.Context) error {
|
||||
return errors.New("not implemented")
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) Close() error {
|
||||
return errors.New("not implemented")
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// MigrationResult represents the result of a single migration.
|
||||
|
@ -150,21 +202,18 @@ type MigrationResult struct{}
|
|||
//
|
||||
// When direction is true, the up migration is executed, and when direction is false, the down
|
||||
// migration is executed.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
|
||||
// empty list and nil error.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// UpByOne applies the next available migration. If there are no migrations to apply, this method
|
||||
// returns [ErrNoNextVersion].
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -174,14 +223,12 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
|||
//
|
||||
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
|
||||
// with a requested version of 10, only versions 9 and 10 will be applied.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
|
||||
// method returns [ErrNoNextVersion].
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
@ -190,7 +237,6 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
|||
//
|
||||
// For instance, if the current database version is 11, and the requested version is 9, only
|
||||
// migrations 11 and 10 will be rolled back.
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
|
|
@ -60,10 +60,24 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
|
|||
})
|
||||
}
|
||||
|
||||
// WithExcludes excludes the given file names from the list of migrations.
|
||||
//
|
||||
// If WithExcludes is called multiple times, the list of excludes is merged.
|
||||
func WithExcludes(excludes []string) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
for _, name := range excludes {
|
||||
c.excludes[name] = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type config struct {
|
||||
tableName string
|
||||
verbose bool
|
||||
excludes map[string]bool
|
||||
|
||||
// Locking options
|
||||
lockEnabled bool
|
||||
sessionLocker lock.SessionLocker
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package provider
|
||||
package provider_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/provider"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
|
@ -15,86 +15,52 @@ func TestNewProvider(t *testing.T) {
|
|||
dir := t.TempDir()
|
||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||
check.NoError(t, err)
|
||||
fsys := newFsys()
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
// Empty dialect not allowed
|
||||
_, err = NewProvider("", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Invalid dialect not allowed
|
||||
_, err = NewProvider("unknown-dialect", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil db not allowed
|
||||
_, err = NewProvider("sqlite3", nil, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil fsys not allowed
|
||||
_, err = NewProvider("sqlite3", db, nil)
|
||||
check.HasError(t, err)
|
||||
// Duplicate table name not allowed
|
||||
_, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"), WithTableName("bar"))
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, `table already set to "foo"`, err.Error())
|
||||
// Empty table name not allowed
|
||||
_, err = NewProvider("sqlite3", db, fsys, WithTableName(""))
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, "table must not be empty", err.Error())
|
||||
})
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
// Valid dialect, db, and fsys allowed
|
||||
_, err = NewProvider("sqlite3", db, fsys)
|
||||
check.NoError(t, err)
|
||||
// Valid dialect, db, fsys, and table name allowed
|
||||
_, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"))
|
||||
check.NoError(t, err)
|
||||
// Valid dialect, db, fsys, and verbose allowed
|
||||
_, err = NewProvider("sqlite3", db, fsys, WithVerbose())
|
||||
check.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func newFsys() fs.FS {
|
||||
return fstest.MapFS{
|
||||
fsys := fstest.MapFS{
|
||||
"1_foo.sql": {Data: []byte(migration1)},
|
||||
"2_bar.sql": {Data: []byte(migration2)},
|
||||
"3_baz.sql": {Data: []byte(migration3)},
|
||||
"4_qux.sql": {Data: []byte(migration4)},
|
||||
}
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
// Empty dialect not allowed
|
||||
_, err = provider.NewProvider("", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Invalid dialect not allowed
|
||||
_, err = provider.NewProvider("unknown-dialect", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil db not allowed
|
||||
_, err = provider.NewProvider("sqlite3", nil, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil fsys not allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, nil)
|
||||
check.HasError(t, err)
|
||||
// Duplicate table name not allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, fsys,
|
||||
provider.WithTableName("foo"),
|
||||
provider.WithTableName("bar"),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, `table already set to "foo"`, err.Error())
|
||||
// Empty table name not allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, fsys,
|
||||
provider.WithTableName(""),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, "table must not be empty", err.Error())
|
||||
})
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
// Valid dialect, db, and fsys allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, fsys)
|
||||
check.NoError(t, err)
|
||||
// Valid dialect, db, fsys, and table name allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, fsys,
|
||||
provider.WithTableName("foo"),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
// Valid dialect, db, fsys, and verbose allowed
|
||||
_, err = provider.NewProvider("sqlite3", db, fsys,
|
||||
provider.WithVerbose(),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
migration1 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
||||
-- +goose Down
|
||||
DROP TABLE foo;
|
||||
`
|
||||
migration2 = `
|
||||
-- +goose Up
|
||||
ALTER TABLE foo ADD COLUMN name TEXT;
|
||||
-- +goose Down
|
||||
ALTER TABLE foo DROP COLUMN name;
|
||||
`
|
||||
migration3 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE bar (
|
||||
id INTEGER PRIMARY KEY,
|
||||
description TEXT
|
||||
);
|
||||
-- +goose Down
|
||||
DROP TABLE bar;
|
||||
`
|
||||
migration4 = `
|
||||
-- +goose Up
|
||||
-- Rename the 'foo' table to 'my_foo'
|
||||
ALTER TABLE foo RENAME TO my_foo;
|
||||
|
||||
-- Add a new column 'timestamp' to 'my_foo'
|
||||
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
||||
|
||||
-- +goose Down
|
||||
-- Remove the 'timestamp' column from 'my_foo'
|
||||
ALTER TABLE my_foo DROP COLUMN timestamp;
|
||||
|
||||
-- Rename the 'my_foo' table back to 'foo'
|
||||
ALTER TABLE my_foo RENAME TO foo;
|
||||
`
|
||||
)
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
package provider_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/provider"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||
check.NoError(t, err)
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
_, err := provider.NewProvider("sqlite3", db, fstest.MapFS{})
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
|
||||
})
|
||||
|
||||
mapFS := fstest.MapFS{
|
||||
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
|
||||
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
p, err := provider.NewProvider("sqlite3", db, fsys)
|
||||
check.NoError(t, err)
|
||||
sources := p.ListSources()
|
||||
check.Equal(t, len(sources), 2)
|
||||
// 1
|
||||
check.Equal(t, sources[0].Version, int64(1))
|
||||
check.Equal(t, sources[0].Fullpath, "001_foo.sql")
|
||||
check.Equal(t, sources[0].Type, provider.SourceTypeSQL)
|
||||
// 2
|
||||
check.Equal(t, sources[1].Version, int64(2))
|
||||
check.Equal(t, sources[1].Fullpath, "002_bar.sql")
|
||||
check.Equal(t, sources[1].Type, provider.SourceTypeSQL)
|
||||
}
|
||||
|
||||
var (
|
||||
migration1 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
||||
-- +goose Down
|
||||
DROP TABLE foo;
|
||||
`
|
||||
migration2 = `
|
||||
-- +goose Up
|
||||
ALTER TABLE foo ADD COLUMN name TEXT;
|
||||
-- +goose Down
|
||||
ALTER TABLE foo DROP COLUMN name;
|
||||
`
|
||||
migration3 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE bar (
|
||||
id INTEGER PRIMARY KEY,
|
||||
description TEXT
|
||||
);
|
||||
-- +goose Down
|
||||
DROP TABLE bar;
|
||||
`
|
||||
migration4 = `
|
||||
-- +goose Up
|
||||
-- Rename the 'foo' table to 'my_foo'
|
||||
ALTER TABLE foo RENAME TO my_foo;
|
||||
|
||||
-- Add a new column 'timestamp' to 'my_foo'
|
||||
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
||||
|
||||
-- +goose Down
|
||||
-- Remove the 'timestamp' column from 'my_foo'
|
||||
ALTER TABLE my_foo DROP COLUMN timestamp;
|
||||
|
||||
-- Rename the 'my_foo' table back to 'foo'
|
||||
ALTER TABLE my_foo RENAME TO foo;
|
||||
`
|
||||
)
|
|
@ -14,19 +14,20 @@ import (
|
|||
)
|
||||
|
||||
func TestPostgresSessionLocker(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.Skip("skip long running test")
|
||||
}
|
||||
db, cleanup, err := testdb.NewPostgres()
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
const (
|
||||
lockID int64 = 123456789
|
||||
)
|
||||
|
||||
// Do not run tests in parallel, because they are using the same database.
|
||||
|
||||
t.Run("lock_and_unlock", func(t *testing.T) {
|
||||
const (
|
||||
lockID int64 = 123456789
|
||||
)
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockID(lockID),
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
|
|
28
migration.go
28
migration.go
|
@ -218,27 +218,27 @@ func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, d
|
|||
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
|
||||
}
|
||||
|
||||
// NumericComponent looks for migration scripts with names in the form:
|
||||
// XXX_descriptivename.ext where XXX specifies the version number
|
||||
// and ext specifies the type of migration
|
||||
func NumericComponent(name string) (int64, error) {
|
||||
base := filepath.Base(name)
|
||||
|
||||
// NumericComponent parses the version from the migration file name.
|
||||
//
|
||||
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
|
||||
// migration, either .sql or .go.
|
||||
func NumericComponent(filename string) (int64, error) {
|
||||
base := filepath.Base(filename)
|
||||
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
|
||||
return 0, errors.New("not a recognized migration file type")
|
||||
return 0, errors.New("migration file does not have .sql or .go file extension")
|
||||
}
|
||||
|
||||
idx := strings.Index(base, "_")
|
||||
if idx < 0 {
|
||||
return 0, errors.New("no filename separator '_' found")
|
||||
}
|
||||
|
||||
n, e := strconv.ParseInt(base[:idx], 10, 64)
|
||||
if e == nil && n <= 0 {
|
||||
return 0, errors.New("migration IDs must be greater than zero")
|
||||
n, err := strconv.ParseInt(base[:idx], 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, e
|
||||
if n < 1 {
|
||||
return 0, errors.New("migration version must be greater than zero")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func truncateDuration(d time.Duration) time.Duration {
|
||||
|
|
Loading…
Reference in New Issue