feat(experimental): add collect migrations logic and new Provider options (#615)

pull/617/head
Michael Fridman 2023-10-14 09:30:01 -04:00 committed by GitHub
parent fe8fe975d8
commit 68853f91ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 621 additions and 139 deletions

View File

@ -33,6 +33,9 @@ tools:
test-packages:
go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
test-packages-short:
go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
test-e2e: test-e2e-postgres test-e2e-mysql test-e2e-clickhouse test-e2e-vertica
test-e2e-postgres:

View File

@ -11,6 +11,9 @@ import (
func TestSequential(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skip long running test")
}
dir := t.TempDir()
defer os.Remove("./bin/create-goose") // clean up

View File

@ -11,6 +11,9 @@ import (
func TestFix(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skip long running test")
}
dir := t.TempDir()
defer os.Remove("./bin/fix-goose") // clean up

View File

@ -8,6 +8,7 @@ import (
)
func TestParsingGoMigrations(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
@ -38,6 +39,7 @@ func TestParsingGoMigrations(t *testing.T) {
}
func TestParsingGoMigrationsError(t *testing.T) {
t.Parallel()
_, err := parseGoFile(strings.NewReader(emptyInit))
check.HasError(t, err)
check.Contains(t, err.Error(), "no registered goose functions")

View File

@ -0,0 +1,176 @@
package provider
import (
"errors"
"fmt"
"io/fs"
"path/filepath"
"sort"
"strings"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/migrate"
)
// fileSources represents a collection of migration files on the filesystem.
type fileSources struct {
sqlSources []Source
goSources []Source
}
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
// in which case an empty fileSources is returned.
//
// If strict is true, then any error parsing the numeric component of the filename will result in an
// error. The file is skipped otherwise.
//
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
// migration sources from the filesystem.
func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
}
sources := new(fileSources)
versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath)
for _, pattern := range []string{
"*.sql",
"*.go",
} {
files, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
}
for _, fullpath := range files {
base := filepath.Base(fullpath)
// Skip explicit excludes or Go test files.
if excludes[base] || strings.HasSuffix(base, "_test.go") {
continue
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := goose.NumericComponent(base)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
continue
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, Source{
Fullpath: fullpath,
Version: version,
})
case ".go":
sources.goSources = append(sources.goSources, Source{
Fullpath: fullpath,
Version: version,
})
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
}
}
return sources, nil
}
func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) {
var migrations []*migrate.Migration
migrationLookup := make(map[int64]*migrate.Migration)
// Add all SQL migrations to the list of migrations.
for _, s := range sources.sqlSources {
m := &migrate.Migration{
Type: migrate.TypeSQL,
Fullpath: s.Fullpath,
Version: s.Version,
SQLParsed: false,
}
migrations = append(migrations, m)
migrationLookup[s.Version] = m
}
// If there are no Go files in the filesystem and no registered Go migrations, return early.
if len(sources.goSources) == 0 && len(registerd) == 0 {
return migrations, nil
}
// Return an error if the given sources contain a versioned Go migration that has not been
// registered. This is a sanity check to ensure users didn't accidentally create a valid looking
// Go migration file on disk and forget to register it.
//
// This is almost always a user error.
var unregistered []string
for _, s := range sources.goSources {
if _, ok := registerd[s.Version]; !ok {
unregistered = append(unregistered, s.Fullpath)
}
}
if len(unregistered) > 0 {
return nil, unregisteredError(unregistered)
}
// Add all registered Go migrations to the list of migrations, checking for duplicate versions.
//
// Important, users can register Go migrations manually via goose.Add_ functions. These
// migrations may not have a corresponding file on disk. Which is fine! We include them
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
// the SQL migration files.
for _, r := range registerd {
// Ensure there are no duplicate versions.
if existing, ok := migrationLookup[r.Version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
r.Version,
existing,
filepath.Base(r.Source),
)
}
m := &migrate.Migration{
Fullpath: r.Source, // May be empty if the migration was registered manually.
Version: r.Version,
Type: migrate.TypeGo,
Go: &migrate.Go{
UseTx: r.UseTx,
UpFn: r.UpFnContext,
UpFnNoTx: r.UpFnNoTxContext,
DownFn: r.DownFnContext,
DownFnNoTx: r.DownFnNoTxContext,
},
}
migrations = append(migrations, m)
migrationLookup[r.Version] = m
}
// Sort migrations by version in ascending order.
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version < migrations[j].Version
})
return migrations, nil
}
func unregisteredError(unregistered []string) error {
f := "file"
if len(unregistered) > 1 {
f += "s"
}
var b strings.Builder
b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f))
for _, name := range unregistered {
b.WriteString("\t" + name + "\n")
}
b.WriteString("\n")
b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations")
return errors.New(b.String())
}

View File

@ -0,0 +1,185 @@
package provider
import (
"io/fs"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3/internal/check"
)
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil", func(t *testing.T) {
sources, err := collectFileSources(nil, false, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty", func(t *testing.T) {
sources, err := collectFileSources(fstest.MapFS{}, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
check.Bool(t, sources != nil, true)
})
t.Run("incorrect_fsys", func(t *testing.T) {
mapFS := fstest.MapFS{
"00000_foo.sql": sqlMapFile,
}
// strict disable - should not error
sources, err := collectFileSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
// strict enabled - should error
_, err = collectFileSources(mapFS, true, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration version must be greater than zero")
})
t.Run("collect", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFileSources(fsys, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
{Fullpath: "00001_foo.sql", Version: 1},
{Fullpath: "00002_bar.sql", Version: 2},
{Fullpath: "00003_baz.sql", Version: 3},
{Fullpath: "00110_qux.sql", Version: 110},
},
}
for i := 0; i < len(sources.sqlSources); i++ {
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
}
})
t.Run("excludes", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFileSources(
fsys,
false,
// exclude 2 files explicitly
map[string]bool{
"00002_bar.sql": true,
"00110_qux.sql": true,
},
)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
{Fullpath: "00001_foo.sql", Version: 1},
{Fullpath: "00003_baz.sql", Version: 3},
},
}
for i := 0; i < len(sources.sqlSources); i++ {
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
}
})
t.Run("strict", func(t *testing.T) {
mapFS := newSQLOnlyFS()
// Add a file with no version number
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
_, err = collectFileSources(fsys, true, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
})
t.Run("skip_go_test_files", func(t *testing.T) {
mapFS := fstest.MapFS{
"1_foo.sql": sqlMapFile,
"2_bar.sql": sqlMapFile,
"3_baz.sql": sqlMapFile,
"4_qux.sql": sqlMapFile,
"5_foo_test.go": {Data: []byte(`package goose_test`)},
}
sources, err := collectFileSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
})
t.Run("skip_random_files", func(t *testing.T) {
mapFS := fstest.MapFS{
"1_foo.sql": sqlMapFile,
"4_something.go": {Data: []byte(`package goose`)},
"5_qux.sql": sqlMapFile,
"README.md": {Data: []byte(`# README`)},
"LICENSE": {Data: []byte(`MIT`)},
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
}
sources, err := collectFileSources(mapFS, false, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
// 1
check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql")
check.Equal(t, sources.sqlSources[0].Version, int64(1))
// 2
check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql")
check.Equal(t, sources.sqlSources[1].Version, int64(5))
// 3
check.Equal(t, sources.goSources[0].Fullpath, "4_something.go")
check.Equal(t, sources.goSources[0].Version, int64(4))
})
t.Run("duplicate_versions", func(t *testing.T) {
mapFS := fstest.MapFS{
"001_foo.sql": sqlMapFile,
"01_bar.sql": sqlMapFile,
}
_, err := collectFileSources(mapFS, false, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
t.Run("dirpath", func(t *testing.T) {
mapFS := fstest.MapFS{
"dir1/101_a.sql": sqlMapFile,
"dir1/102_b.sql": sqlMapFile,
"dir1/103_c.sql": sqlMapFile,
"dir2/201_a.sql": sqlMapFile,
"876_a.sql": sqlMapFile,
}
assertDirpath := func(dirpath string, sqlSources []Source) {
t.Helper()
f, err := fs.Sub(mapFS, dirpath)
check.NoError(t, err)
got, err := collectFileSources(f, false, nil)
check.NoError(t, err)
check.Number(t, len(got.sqlSources), len(sqlSources))
check.Number(t, len(got.goSources), 0)
for i := 0; i < len(got.sqlSources); i++ {
check.Equal(t, got.sqlSources[i], sqlSources[i])
}
}
assertDirpath(".", []Source{
{Fullpath: "876_a.sql", Version: 876},
})
assertDirpath("dir1", []Source{
{Fullpath: "101_a.sql", Version: 101},
{Fullpath: "102_b.sql", Version: 102},
{Fullpath: "103_c.sql", Version: 103},
})
assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}})
assertDirpath("dir3", nil)
})
}
func newSQLOnlyFS() fstest.MapFS {
return fstest.MapFS{
"migrations/00001_foo.sql": sqlMapFile,
"migrations/00002_bar.sql": sqlMapFile,
"migrations/00003_baz.sql": sqlMapFile,
"migrations/00110_qux.sql": sqlMapFile,
}
}
var (
sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)}
)

View File

@ -5,22 +5,29 @@ import (
"database/sql"
"errors"
"io/fs"
"os"
"time"
"github.com/pressly/goose/v3/internal/migrate"
"github.com/pressly/goose/v3/internal/sqladapter"
)
var (
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
ErrNoMigrations = errors.New("no migrations found")
)
// NewProvider returns a new goose Provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
// example, if the database dialect is "postgres", the database/sql driver could be
// github.com/lib/pq or github.com/jackc/pgx.
//
// fsys is the filesystem used to read the migration files. Most users will want to use
// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
// possible to use a different filesystem, such as embed.FS.
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub.
//
// Functional options are used to configure the Provider. See [ProviderOption] for more information.
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
@ -33,7 +40,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
return nil, errors.New("dialect must not be empty")
}
if fsys == nil {
return nil, errors.New("fsys must not be nil")
fsys = noopFS{}
}
var cfg config
for _, opt := range opts {
@ -41,7 +48,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
return nil, err
}
}
// Set defaults
// Set defaults after applying user-supplied options so option funcs can check for empty values.
if cfg.tableName == "" {
cfg.tableName = defaultTablename
}
@ -49,41 +56,76 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption)
if err != nil {
return nil, err
}
// TODO(mf): implement the rest of this function - collect sources - merge sources into
// migrations
// Collect migrations from the filesystem and merge with registered migrations.
//
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
// lazily.
//
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
// though, so we should make it optional.
sources, err := collectFileSources(fsys, false, cfg.excludes)
if err != nil {
return nil, err
}
migrations, err := merge(sources, nil)
if err != nil {
return nil, err
}
if len(migrations) == 0 {
return nil, ErrNoMigrations
}
return &Provider{
db: db,
fsys: fsys,
cfg: cfg,
store: store,
db: db,
fsys: fsys,
cfg: cfg,
store: store,
migrations: migrations,
}, nil
}
// Provider is a goose migration provider.
// Experimental: This API is experimental and may change in the future.
type Provider struct {
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
type noopFS struct{}
var _ fs.FS = noopFS{}
func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}
// Provider is a goose migration provider.
type Provider struct {
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
migrations []*migrate.Migration
}
// State represents the state of a migration.
type State string
const (
// StateUntracked represents a migration that is in the database, but not on the filesystem.
StateUntracked State = "untracked"
// StatePending represents a migration that is on the filesystem, but not in the database.
StatePending State = "pending"
// StateApplied represents a migration that is in BOTH the database and on the filesystem.
StateApplied State = "applied"
)
// MigrationStatus represents the status of a single migration.
type MigrationStatus struct {
// State represents the state of the migration. One of "untracked", "pending", "applied".
// - untracked: in the database, but not on the filesystem.
// - pending: on the filesystem, but not in the database.
// - applied: in both the database and on the filesystem.
State string
// AppliedAt is the time the migration was applied. Only set if state is applied or untracked.
// State is the state of the migration.
State State
// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
// [StateUntracked].
AppliedAt time.Time
// Source is the migration source. Only set if the state is pending or applied.
Source Source
// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
Source *Source
}
// Status returns the status of all migrations, merging the list of migrations from the database and
// filesystem. The returned items are ordered by version, in ascending order.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return nil, errors.New("not implemented")
}
@ -91,7 +133,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
// GetDBVersion returns the max version from the database, regardless of the applied order. For
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
// applied, it returns 0.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
@ -111,7 +152,6 @@ const (
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
// the migration has a corresponding file on disk. It will be empty if the migration was registered
// manually.
// Experimental: This API is experimental and may change in the future.
type Source struct {
// Type is the type of migration.
Type SourceType
@ -123,22 +163,34 @@ type Source struct {
Version int64
}
// ListSources returns a list of all available migration sources the provider is aware of.
// Experimental: This API is experimental and may change in the future.
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
// ascending order by version.
func (p *Provider) ListSources() []*Source {
return nil
sources := make([]*Source, 0, len(p.migrations))
for _, m := range p.migrations {
s := &Source{
Fullpath: m.Fullpath,
Version: m.Version,
}
switch m.Type {
case migrate.TypeSQL:
s.Type = SourceTypeSQL
case migrate.TypeGo:
s.Type = SourceTypeGo
}
sources = append(sources, s)
}
return sources
}
// Ping attempts to ping the database to verify a connection is available.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) Ping(ctx context.Context) error {
return errors.New("not implemented")
return p.db.PingContext(ctx)
}
// Close closes the database connection.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) Close() error {
return errors.New("not implemented")
return p.db.Close()
}
// MigrationResult represents the result of a single migration.
@ -150,21 +202,18 @@ type MigrationResult struct{}
//
// When direction is true, the up migration is executed, and when direction is false, the down
// migration is executed.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
return nil, errors.New("not implemented")
}
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
// empty list and nil error.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
return nil, errors.New("not implemented")
}
// UpByOne applies the next available migration. If there are no migrations to apply, this method
// returns [ErrNoNextVersion].
// Experimental: This API is experimental and may change in the future.
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
return nil, errors.New("not implemented")
}
@ -174,14 +223,12 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
//
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9 and 10 will be applied.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
return nil, errors.New("not implemented")
}
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
// method returns [ErrNoNextVersion].
// Experimental: This API is experimental and may change in the future.
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
return nil, errors.New("not implemented")
}
@ -190,7 +237,6 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
//
// For instance, if the current database version is 11, and the requested version is 9, only
// migrations 11 and 10 will be rolled back.
// Experimental: This API is experimental and may change in the future.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
return nil, errors.New("not implemented")
}

View File

@ -60,10 +60,24 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
})
}
// WithExcludes excludes the given file names from the list of migrations.
//
// If WithExcludes is called multiple times, the list of excludes is merged.
func WithExcludes(excludes []string) ProviderOption {
return configFunc(func(c *config) error {
for _, name := range excludes {
c.excludes[name] = true
}
return nil
})
}
type config struct {
tableName string
verbose bool
excludes map[string]bool
// Locking options
lockEnabled bool
sessionLocker lock.SessionLocker
}

View File

@ -1,13 +1,13 @@
package provider
package provider_test
import (
"database/sql"
"io/fs"
"path/filepath"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
_ "modernc.org/sqlite"
)
@ -15,86 +15,52 @@ func TestNewProvider(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
fsys := newFsys()
t.Run("invalid", func(t *testing.T) {
// Empty dialect not allowed
_, err = NewProvider("", db, fsys)
check.HasError(t, err)
// Invalid dialect not allowed
_, err = NewProvider("unknown-dialect", db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = NewProvider("sqlite3", nil, fsys)
check.HasError(t, err)
// Nil fsys not allowed
_, err = NewProvider("sqlite3", db, nil)
check.HasError(t, err)
// Duplicate table name not allowed
_, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"), WithTableName("bar"))
check.HasError(t, err)
check.Equal(t, `table already set to "foo"`, err.Error())
// Empty table name not allowed
_, err = NewProvider("sqlite3", db, fsys, WithTableName(""))
check.HasError(t, err)
check.Equal(t, "table must not be empty", err.Error())
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = NewProvider("sqlite3", db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and table name allowed
_, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"))
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = NewProvider("sqlite3", db, fsys, WithVerbose())
check.NoError(t, err)
})
}
func newFsys() fs.FS {
return fstest.MapFS{
fsys := fstest.MapFS{
"1_foo.sql": {Data: []byte(migration1)},
"2_bar.sql": {Data: []byte(migration2)},
"3_baz.sql": {Data: []byte(migration3)},
"4_qux.sql": {Data: []byte(migration4)},
}
t.Run("invalid", func(t *testing.T) {
// Empty dialect not allowed
_, err = provider.NewProvider("", db, fsys)
check.HasError(t, err)
// Invalid dialect not allowed
_, err = provider.NewProvider("unknown-dialect", db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = provider.NewProvider("sqlite3", nil, fsys)
check.HasError(t, err)
// Nil fsys not allowed
_, err = provider.NewProvider("sqlite3", db, nil)
check.HasError(t, err)
// Duplicate table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
provider.WithTableName("bar"),
)
check.HasError(t, err)
check.Equal(t, `table already set to "foo"`, err.Error())
// Empty table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName(""),
)
check.HasError(t, err)
check.Equal(t, "table must not be empty", err.Error())
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = provider.NewProvider("sqlite3", db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and table name allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
)
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithVerbose(),
)
check.NoError(t, err)
})
}
var (
migration1 = `
-- +goose Up
CREATE TABLE foo (id INTEGER PRIMARY KEY);
-- +goose Down
DROP TABLE foo;
`
migration2 = `
-- +goose Up
ALTER TABLE foo ADD COLUMN name TEXT;
-- +goose Down
ALTER TABLE foo DROP COLUMN name;
`
migration3 = `
-- +goose Up
CREATE TABLE bar (
id INTEGER PRIMARY KEY,
description TEXT
);
-- +goose Down
DROP TABLE bar;
`
migration4 = `
-- +goose Up
-- Rename the 'foo' table to 'my_foo'
ALTER TABLE foo RENAME TO my_foo;
-- Add a new column 'timestamp' to 'my_foo'
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
-- +goose Down
-- Remove the 'timestamp' column from 'my_foo'
ALTER TABLE my_foo DROP COLUMN timestamp;
-- Rename the 'my_foo' table back to 'foo'
ALTER TABLE my_foo RENAME TO foo;
`
)

View File

@ -0,0 +1,83 @@
package provider_test
import (
"database/sql"
"errors"
"io/fs"
"path/filepath"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
_ "modernc.org/sqlite"
)
func TestProvider(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := provider.NewProvider("sqlite3", db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
})
mapFS := fstest.MapFS{
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := provider.NewProvider("sqlite3", db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
// 1
check.Equal(t, sources[0].Version, int64(1))
check.Equal(t, sources[0].Fullpath, "001_foo.sql")
check.Equal(t, sources[0].Type, provider.SourceTypeSQL)
// 2
check.Equal(t, sources[1].Version, int64(2))
check.Equal(t, sources[1].Fullpath, "002_bar.sql")
check.Equal(t, sources[1].Type, provider.SourceTypeSQL)
}
var (
migration1 = `
-- +goose Up
CREATE TABLE foo (id INTEGER PRIMARY KEY);
-- +goose Down
DROP TABLE foo;
`
migration2 = `
-- +goose Up
ALTER TABLE foo ADD COLUMN name TEXT;
-- +goose Down
ALTER TABLE foo DROP COLUMN name;
`
migration3 = `
-- +goose Up
CREATE TABLE bar (
id INTEGER PRIMARY KEY,
description TEXT
);
-- +goose Down
DROP TABLE bar;
`
migration4 = `
-- +goose Up
-- Rename the 'foo' table to 'my_foo'
ALTER TABLE foo RENAME TO my_foo;
-- Add a new column 'timestamp' to 'my_foo'
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
-- +goose Down
-- Remove the 'timestamp' column from 'my_foo'
ALTER TABLE my_foo DROP COLUMN timestamp;
-- Rename the 'my_foo' table back to 'foo'
ALTER TABLE my_foo RENAME TO foo;
`
)

View File

@ -14,19 +14,20 @@ import (
)
func TestPostgresSessionLocker(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skip long running test")
}
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
const (
lockID int64 = 123456789
)
// Do not run tests in parallel, because they are using the same database.
t.Run("lock_and_unlock", func(t *testing.T) {
const (
lockID int64 = 123456789
)
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithLockTimeout(4*time.Second),

View File

@ -218,27 +218,27 @@ func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, d
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
}
// NumericComponent looks for migration scripts with names in the form:
// XXX_descriptivename.ext where XXX specifies the version number
// and ext specifies the type of migration
func NumericComponent(name string) (int64, error) {
base := filepath.Base(name)
// NumericComponent parses the version from the migration file name.
//
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
// migration, either .sql or .go.
func NumericComponent(filename string) (int64, error) {
base := filepath.Base(filename)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("not a recognized migration file type")
return 0, errors.New("migration file does not have .sql or .go file extension")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no filename separator '_' found")
}
n, e := strconv.ParseInt(base[:idx], 10, 64)
if e == nil && n <= 0 {
return 0, errors.New("migration IDs must be greater than zero")
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
}
return n, e
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
}
return n, nil
}
func truncateDuration(d time.Duration) time.Duration {