feat: Add goose provider (#635)

pull/637/head
Michael Fridman 2023-11-09 09:23:37 -05:00 committed by GitHub
parent 8503d4e20b
commit 04e12b88f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1537 additions and 1702 deletions

View File

@ -2,6 +2,7 @@ package database
import (
"context"
"database/sql"
"errors"
"fmt"
@ -100,6 +101,9 @@ func (s *store) GetMigration(
&result.Timestamp,
&result.IsApplied,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
}
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return &result, nil

View File

@ -2,9 +2,15 @@ package database
import (
"context"
"errors"
"time"
)
var (
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
ErrVersionNotFound = errors.New("version not found")
)
// Store is an interface that defines methods for managing database migrations and versioning. By
// defining a Store interface, we can support multiple databases with consistent functionality.
//
@ -24,8 +30,8 @@ type Store interface {
// Delete deletes a version id from the version table.
Delete(ctx context.Context, db DBTxConn, version int64) error
// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
// version is not found, this method must return [ErrVersionNotFound].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If

View File

@ -205,7 +205,7 @@ func testStore(
err = runConn(ctx, db, func(conn *sql.Conn) error {
_, err := store.GetMigration(ctx, conn, 0)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
return nil
})
check.NoError(t, err)

View File

@ -22,13 +22,12 @@ func ResetGlobalMigrations() {
// [NewGoMigration] function.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, migration := range migrations {
m := &migration
func SetGlobalMigrations(migrations ...*Migration) error {
for _, m := range migrations {
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
if err := checkMigration(m); err != nil {
if err := checkGoMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
registeredGoMigrations[m.Version] = m
@ -36,7 +35,7 @@ func SetGlobalMigrations(migrations ...Migration) error {
return nil
}
func checkMigration(m *Migration) error {
func checkGoMigration(m *Migration) error {
if !m.construct {
return errors.New("must use NewGoMigration to construct migrations")
}
@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
}
}
if err := setGoFunc(m.goUp); err != nil {
if err := checkGoFunc(m.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
}
if err := setGoFunc(m.goDown); err != nil {
if err := checkGoFunc(m.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
return nil
}
func setGoFunc(f *GoFunc) error {
if f == nil {
f = &GoFunc{Mode: TransactionEnabled}
return nil
}
func checkGoFunc(f *GoFunc) error {
if f.RunTx != nil && f.RunDB != nil {
return errors.New("must specify exactly one of RunTx or RunDB")
}
if f.RunTx == nil && f.RunDB == nil {
switch f.Mode {
case 0:
// Default to TransactionEnabled ONLY if mode is not set explicitly.
f.Mode = TransactionEnabled
case TransactionEnabled, TransactionDisabled:
// No functions but mode is set. This is not an error. It means the user wants to record
// a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
return nil
switch f.Mode {
case TransactionEnabled, TransactionDisabled:
// No functions, but mode is set. This is not an error. It means the user wants to
// record a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
if f.RunDB != nil {
switch f.Mode {
case 0, TransactionDisabled:
f.Mode = TransactionDisabled
default:
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
if f.RunDB != nil && f.Mode != TransactionDisabled {
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
if f.RunTx != nil {
switch f.Mode {
case 0, TransactionEnabled:
f.Mode = TransactionEnabled
default:
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
}
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
// the functions or return an error. This should never happen.
if f.Mode == 0 {
return errors.New("failed to infer transaction mode")
if f.RunTx != nil && f.Mode != TransactionEnabled {
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
return nil
}

View File

@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
// reset so we can check the default is set
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
err = SetGlobalMigrations(migration2)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 2)
registered = registeredGoMigrations[2]
check.Bool(t, registered.goUp != nil, true)
check.Bool(t, registered.goDown != nil, true)
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
migration3 := NewGoMigration(3, nil, nil)
// reset so we can check the default is set
migration3.goDown.Mode = 0
err = SetGlobalMigrations(migration3)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
})
t.Run("unknown_mode", func(t *testing.T) {
m := NewGoMigration(1, nil, nil)
@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
runTx := func(context.Context, *sql.Tx) error { return nil }
// Success.
err := SetGlobalMigrations([]Migration{}...)
err := SetGlobalMigrations([]*Migration{}...)
check.NoError(t, err)
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
}
func TestCheckMigration(t *testing.T) {
// Failures.
err := checkMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "must be registered")
err = checkMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "version must be greater than zero")
// Success.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
err := checkGoMigration(NewGoMigration(1, nil, nil))
check.NoError(t, err)
// Failures.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
err = checkGoMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkGoMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "must be registered")
err = checkGoMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "version must be greater than zero")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: invalid mode: 0")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: invalid mode: 0")
// Success.
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
check.NoError(t, err)
// Failures.
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `no filename separator '_' found`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFn: func(*sql.Tx) error { return nil },
UpFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFn: func(*sql.Tx) error { return nil },
DownFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")

View File

@ -1,186 +0,0 @@
package provider
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"github.com/pressly/goose/v3/database"
)
type migration struct {
Source Source
// A migration is either a Go migration or a SQL migration, but never both.
//
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
// majority of the time migrations are incremental, so it is likely that the user will only want
// to run the last few migrations, and there is no need to parse ALL prior migrations.
//
// Exactly one of these fields will be set:
Go *goMigration
// -- OR --
SQL *sqlMigration
}
func (m *migration) useTx(direction bool) bool {
switch m.Source.Type {
case TypeSQL:
return m.SQL.UseTx
case TypeGo:
if m.Go == nil || m.Go.isEmpty(direction) {
return false
}
if direction {
return m.Go.up.Run != nil
}
return m.Go.down.Run != nil
}
// This should never happen.
return false
}
func (m *migration) isEmpty(direction bool) bool {
switch m.Source.Type {
case TypeSQL:
return m.SQL == nil || m.SQL.isEmpty(direction)
case TypeGo:
return m.Go == nil || m.Go.isEmpty(direction)
}
return true
}
func (m *migration) filename() string {
return filepath.Base(m.Source.Path)
}
// run runs the migration inside of a transaction.
func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil {
return fmt.Errorf("tx: sql migration has not been parsed")
}
return m.SQL.run(ctx, tx, direction)
case TypeGo:
return m.Go.run(ctx, tx, direction)
}
// This should never happen.
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
// runNoTx runs the migration without a transaction.
func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil {
return fmt.Errorf("db: sql migration has not been parsed")
}
return m.SQL.run(ctx, db, direction)
case TypeGo:
return m.Go.runNoTx(ctx, db, direction)
}
// This should never happen.
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
// runConn runs the migration without a transaction using the provided connection.
func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error {
switch m.Source.Type {
case TypeSQL:
if m.SQL == nil {
return fmt.Errorf("conn: sql migration has not been parsed")
}
return m.SQL.run(ctx, conn, direction)
case TypeGo:
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
}
// This should never happen.
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
}
type goMigration struct {
fullpath string
up, down *GoMigrationFunc
}
func (g *goMigration) isEmpty(direction bool) bool {
if g.up == nil && g.down == nil {
panic("go migration has no up or down")
}
if direction {
return g.up == nil
}
return g.down == nil
}
func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration {
return &goMigration{
fullpath: fullpath,
up: up,
down: down,
}
}
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.Tx) error
if direction && g.up != nil {
fn = g.up.Run
}
if !direction && g.down != nil {
fn = g.down.Run
}
if fn != nil {
return fn(ctx, tx)
}
return nil
}
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.DB) error
if direction && g.up != nil {
fn = g.up.RunNoTx
}
if !direction && g.down != nil {
fn = g.down.RunNoTx
}
if fn != nil {
return fn(ctx, db)
}
return nil
}
type sqlMigration struct {
UseTx bool
UpStatements []string
DownStatements []string
}
func (s *sqlMigration) isEmpty(direction bool) bool {
if direction {
return len(s.UpStatements) == 0
}
return len(s.DownStatements) == 0
}
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
} else {
statements = s.DownStatements
}
for _, stmt := range statements {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}
}
return nil
}

View File

@ -1,68 +0,0 @@
package provider
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3"
)
type MigrationCopy struct {
Version int64
Source string // path to .sql script or go file
Registered bool
UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
}
var registeredGoMigrations = make(map[int64]*MigrationCopy)
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
// the migrations are nil or if a migration with the same version has already been registered.
//
// Not safe for concurrent use.
func SetGlobalGoMigrations(migrations []*MigrationCopy) error {
for _, m := range migrations {
if m == nil {
return errors.New("cannot register nil go migration")
}
if m.Version < 1 {
return errors.New("migration versions must be greater than zero")
}
if !m.Registered {
return errors.New("migration must be registered")
}
if m.Source != "" {
// If the source is set, expect it to be a file path with a numeric component that
// matches the version.
version, err := goose.NumericComponent(m.Source)
if err != nil {
return err
}
if version != m.Version {
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
}
}
// It's valid for all of these to be nil.
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
registeredGoMigrations[m.Version] = m
}
return nil
}
// ResetGlobalGoMigrations resets the global go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalGoMigrations() {
registeredGoMigrations = make(map[int64]*MigrationCopy)
}

View File

@ -1,272 +0,0 @@
package provider
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"math"
"sync"
"github.com/pressly/goose/v3/database"
)
// Provider is a goose migration provider.
type Provider struct {
// mu protects all accesses to the provider and must be held when calling operations on the
// database.
mu sync.Mutex
db *sql.DB
fsys fs.FS
cfg config
store database.Store
// migrations are ordered by version in ascending order.
migrations []*migration
}
// NewProvider returns a new goose Provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
// example, if the database dialect is "postgres", the database/sql driver could be
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
// as using a custom table name or supplying a custom store implementation, see [WithStore].
//
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
// migrations using [fs.Sub].
//
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
if fsys == nil {
fsys = noopFS{}
}
cfg := config{
registered: make(map[int64]*goMigration),
excludes: make(map[string]bool),
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
return nil, err
}
}
// Allow users to specify a custom store implementation, but only if they don't specify a
// dialect. If they specify a dialect, we'll use the default store implementation.
if dialect == "" && cfg.store == nil {
return nil, errors.New("dialect must not be empty")
}
if dialect != "" && cfg.store != nil {
return nil, errors.New("cannot set both dialect and custom store")
}
var store database.Store
if dialect != "" {
var err error
store, err = database.NewStore(dialect, DefaultTablename)
if err != nil {
return nil, err
}
} else {
store = cfg.store
}
if store.Tablename() == "" {
return nil, errors.New("invalid store implementation: table name must not be empty")
}
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
}
func newProvider(
db *sql.DB,
store database.Store,
fsys fs.FS,
cfg config,
global map[int64]*MigrationCopy,
) (*Provider, error) {
// Collect migrations from the filesystem and merge with registered migrations.
//
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
// lazily.
//
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
// though, so we should make it optional.
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes)
if err != nil {
return nil, err
}
registered := make(map[int64]*goMigration)
// Add user-registered Go migrations.
for version, m := range cfg.registered {
registered[version] = newGoMigration("", m.up, m.down)
}
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
// struct to the new goMigration struct.
for version, m := range global {
if _, ok := registered[version]; ok {
return nil, fmt.Errorf("go migration with version %d already registered", version)
}
if m == nil {
return nil, errors.New("registered migration with nil init function")
}
g := newGoMigration(m.Source, nil, nil)
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext")
}
// Up
if m.UpFnContext != nil {
g.up = &GoMigrationFunc{
Run: m.UpFnContext,
}
} else if m.UpFnNoTxContext != nil {
g.up = &GoMigrationFunc{
RunNoTx: m.UpFnNoTxContext,
}
}
// Down
if m.DownFnContext != nil {
g.down = &GoMigrationFunc{
Run: m.DownFnContext,
}
} else if m.DownFnNoTxContext != nil {
g.down = &GoMigrationFunc{
RunNoTx: m.DownFnNoTxContext,
}
}
registered[version] = g
}
migrations, err := merge(filesystemSources, registered)
if err != nil {
return nil, err
}
if len(migrations) == 0 {
return nil, ErrNoMigrations
}
return &Provider{
db: db,
fsys: fsys,
cfg: cfg,
store: store,
migrations: migrations,
}, nil
}
// Status returns the status of all migrations, merging the list of migrations from the database and
// filesystem. The returned items are ordered by version, in ascending order.
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return p.status(ctx)
}
// GetDBVersion returns the max version from the database, regardless of the applied order. For
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
// applied, it returns 0.
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return p.getDBVersion(ctx)
}
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
// ascending order by version.
func (p *Provider) ListSources() []Source {
sources := make([]Source, 0, len(p.migrations))
for _, m := range p.migrations {
sources = append(sources, m.Source)
}
return sources
}
// Ping attempts to ping the database to verify a connection is available.
func (p *Provider) Ping(ctx context.Context) error {
return p.db.PingContext(ctx)
}
// Close closes the database connection.
func (p *Provider) Close() error {
return p.db.Close()
}
// ApplyVersion applies exactly one migration by version. If there is no source for the specified
// version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
// this method returns [ErrAlreadyApplied].
//
// When direction is true, the up migration is executed, and when direction is false, the down
// migration is executed.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.apply(ctx, version, direction)
}
// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
// returns empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
return p.up(ctx, false, math.MaxInt64)
}
// UpByOne applies the next available migration. If there are no migrations to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, ErrNoNextVersion
}
// This should never happen. We should always have exactly one result and test for this.
if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
}
return res[0], nil
}
// UpTo applies all available migrations up to, and including, the specified version. If there are
// no migrations to apply, this method returns empty list and nil error.
//
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.up(ctx, false, version)
}
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
// method returns [ErrNoNextVersion].
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
res, err := p.down(ctx, true, 0)
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, ErrNoNextVersion
}
if len(res) > 1 {
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
}
return res[0], nil
}
// DownTo rolls back all migrations down to, but not including, the specified version.
//
// For instance, if the current database version is 11,10,9... and the requested version is 9, only
// migrations 11, 10 will be rolled back.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 0 {
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
}
return p.down(ctx, false, version)
}

View File

@ -1,153 +0,0 @@
package provider_test
import (
"context"
"database/sql"
"errors"
"io/fs"
"path/filepath"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
_ "modernc.org/sqlite"
)
func TestProvider(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := provider.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
})
mapFS := fstest.MapFS{
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1))
check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2))
t.Run("duplicate_go", func(t *testing.T) {
// Not parallel because it modifies global state.
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: nil,
DownFnContext: nil,
},
}
err := provider.SetGlobalGoMigrations(register)
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(1, nil, nil),
)
check.HasError(t, err)
check.Equal(t, err.Error(), "go migration with version 1 already registered")
})
t.Run("empty_go", func(t *testing.T) {
db := newDB(t)
// explicit
_, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}),
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 must have an up function")
})
t.Run("duplicate_up", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
})
t.Run("duplicate_down", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
})
t.Run("not_registered", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go",
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration must be registered")
})
t.Run("zero_not_allowed", func(t *testing.T) {
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
{
Version: 0,
},
})
t.Cleanup(provider.ResetGlobalGoMigrations)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration versions must be greater than zero")
})
}
var (
migration1 = `
-- +goose Up
CREATE TABLE foo (id INTEGER PRIMARY KEY);
-- +goose Down
DROP TABLE foo;
`
migration2 = `
-- +goose Up
ALTER TABLE foo ADD COLUMN name TEXT;
-- +goose Down
ALTER TABLE foo DROP COLUMN name;
`
migration3 = `
-- +goose Up
CREATE TABLE bar (
id INTEGER PRIMARY KEY,
description TEXT
);
-- +goose Down
DROP TABLE bar;
`
migration4 = `
-- +goose Up
-- Rename the 'foo' table to 'my_foo'
ALTER TABLE foo RENAME TO my_foo;
-- Add a new column 'timestamp' to 'my_foo'
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
-- +goose Down
-- Remove the 'timestamp' column from 'my_foo'
ALTER TABLE my_foo DROP COLUMN timestamp;
-- Rename the 'my_foo' table back to 'foo'
ALTER TABLE my_foo RENAME TO foo;
`
)

View File

@ -1,516 +0,0 @@
package provider
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
var (
errMissingZeroVersion = errors.New("missing zero version migration")
)
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
var apply []*migration
if p.cfg.disableVersioning {
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required
// to support the allow missing (out-of-order) feature. For users that don't use this
// feature, we could just query the database for the current max version and then apply
// migrations greater than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
apply, err = p.resolveUpMigrations(dbMigrations, version)
if err != nil {
return nil, err
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Careful, we can't use a single transaction for all migrations because some may have to be run
// in their own transaction.
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
}
func (p *Provider) resolveUpMigrations(
dbVersions []*database.ListMigrationsResult,
version int64,
) ([]*migration, error) {
var apply []*migration
var dbMaxVersion int64
// dbAppliedVersions is a map of all applied migrations in the database.
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
for _, m := range dbVersions {
dbAppliedVersions[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
// sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max). Skip any migrations that have already been applied.
for _, m := range p.migrations {
if dbAppliedVersions[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
apply = append(apply, m)
}
}
return apply, nil
}
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.disableVersioning {
downMigrations := p.migrations
if downByOne {
last := p.migrations[len(p.migrations)-1]
downMigrations = []*migration{last}
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
// runMigrations runs migrations sequentially in the given direction.
//
// If the migrations list is empty, return nil without error.
func (p *Provider) runMigrations(
ctx context.Context,
conn *sql.Conn,
migrations []*migration,
direction sqlparser.Direction,
byOne bool,
) ([]*MigrationResult, error) {
if len(migrations) == 0 {
return nil, nil
}
apply := migrations
if byOne {
apply = migrations[:1]
}
// Lazily parse SQL migrations (if any) in both directions. We do this before running any
// migrations so that we can fail fast if there are any errors and avoid leaving the database in
// a partially migrated state.
if err := parseSQL(p.fsys, false, apply); err != nil {
return nil, err
}
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
// to run in a transaction.
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
// this will deadlock because the Go migration will try to acquire a connection from the pool,
// but the pool is locked.
//
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
// an edge case.
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
for _, m := range apply {
switch m.Source.Type {
case TypeGo:
if m.Go != nil && m.useTx(direction.ToBool()) {
return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
}
}
}
}
// Avoid allocating a slice because we may have a partial migration error.
// 1. Avoid giving the impression that N migrations were applied when in fact some were not
// 2. Avoid the caller having to check for nil results
var results []*MigrationResult
for _, m := range apply {
current := &MigrationResult{
Source: m.Source,
Direction: direction.String(),
Empty: m.isEmpty(direction.ToBool()),
}
start := time.Now()
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
// TODO(mf): we should also return the pending migrations here, the remaining items in
// the apply slice.
current.Error = err
current.Duration = time.Since(start)
return nil, &PartialError{
Applied: results,
Failed: current,
Err: err,
}
}
current.Duration = time.Since(start)
results = append(results, current)
}
return results, nil
}
func (p *Provider) runIndividually(
ctx context.Context,
conn *sql.Conn,
direction bool,
m *migration,
) error {
if m.useTx(direction) {
// Run the migration in a transaction.
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := m.run(ctx, tx, direction); err != nil {
return err
}
if p.cfg.disableVersioning {
return nil
}
if direction {
return p.store.Insert(ctx, tx, database.InsertRequest{Version: m.Source.Version})
}
return p.store.Delete(ctx, tx, m.Source.Version)
})
}
// Run the migration outside of a transaction.
switch m.Source.Type {
case TypeGo:
// Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the
// GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open
// connections to 1. See the comment in runMigrations for more details.
if err := m.runNoTx(ctx, p.db, direction); err != nil {
return err
}
case TypeSQL:
if err := m.runConn(ctx, conn, direction); err != nil {
return err
}
}
if p.cfg.disableVersioning {
return nil
}
if direction {
return p.store.Insert(ctx, conn, database.InsertRequest{Version: m.Source.Version})
}
return p.store.Delete(ctx, conn, m.Source.Version)
}
// beginTx begins a transaction and runs the given function. If the function returns an error, the
// transaction is rolled back. Otherwise, the transaction is committed.
//
// If the provider is configured to use versioning, this function also inserts or deletes the
// migration version.
func (p *Provider) beginTx(
ctx context.Context,
conn *sql.Conn,
fn func(tx *sql.Tx) error,
) (retErr error) {
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if retErr != nil {
retErr = multierr.Append(retErr, tx.Rollback())
}
}()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
p.mu.Lock()
conn, err := p.db.Conn(ctx)
if err != nil {
p.mu.Unlock()
return nil, nil, err
}
// cleanup is a function that cleans up the connection, and optionally, the session lock.
cleanup := func() error {
p.mu.Unlock()
return conn.Close()
}
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
if err := l.SessionLock(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
cleanup = func() error {
p.mu.Unlock()
// Use a detached context to unlock the session. This is because the context passed to
// SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
// use [context.WithoutCancel] added in go1.21
detachedCtx := context.Background()
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
}
}
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
// need the version table because there is no versioning.
if !p.cfg.disableVersioning {
if err := p.ensureVersionTable(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
}
return conn, cleanup, nil
}
// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
// will not be parsed again.
//
// Important: This function will mutate SQL migrations and is not safe for concurrent use.
func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
for _, m := range migrations {
// If the migration is a SQL migration, and it has not been parsed, parse it.
if m.Source.Type == TypeSQL && m.SQL == nil {
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
if err != nil {
return err
}
m.SQL = &sqlMigration{
UseTx: parsed.UseTx,
UpStatements: parsed.Up,
DownStatements: parsed.Down,
}
}
}
return nil
}
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
if p.cfg.disableVersioning {
return nil
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
})
}
type missingMigration struct {
versionID int64
filename string
}
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
// migration is one that has a version less than the max version in the database.
func checkMissingMigrations(
dbMigrations []*database.ListMigrationsResult,
fsMigrations []*migration,
) []missingMigration {
existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations {
existing[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
var missing []missingMigration
for _, m := range fsMigrations {
version := m.Source.Version
if !existing[version] && version < dbMaxVersion {
missing = append(missing, missingMigration{
versionID: version,
filename: m.filename(),
})
}
}
sort.Slice(missing, func(i, j int) bool {
return missing[i].versionID < missing[j].versionID
})
return missing
}
// getMigration returns the migration with the given version. If no migration is found, then
// ErrVersionNotFound is returned.
func (p *Provider) getMigration(version int64) (*migration, error) {
for _, m := range p.migrations {
if m.Source.Version == version {
return m, nil
}
}
return nil, ErrVersionNotFound
}
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
m, err := p.getMigration(version)
if err != nil {
return nil, err
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
result, err := p.store.GetMigration(ctx, conn, version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
// If the migration has already been applied, return an error, unless the migration is being
// applied in the opposite direction. In that case, we allow the migration to be applied again.
if result != nil && direction {
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
}
d := sqlparser.DirectionDown
if direction {
d = sqlparser.DirectionUp
}
results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
}
return results[0], nil
}
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
// TODO(mf): add support for limit and order. Also would be nice to refactor the list query to
// support limiting the set.
status := make([]*MigrationStatus, 0, len(p.migrations))
for _, m := range p.migrations {
migrationStatus := &MigrationStatus{
Source: m.Source,
State: StatePending,
}
dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
if dbResult != nil {
migrationStatus.State = StateApplied
migrationStatus.AppliedAt = dbResult.Timestamp
}
status = append(status, migrationStatus)
}
return status, nil
}
func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return 0, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
res, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return 0, err
}
if len(res) == 0 {
return 0, nil
}
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
return res[0].Version, nil
}

View File

@ -13,17 +13,25 @@ import (
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
// session-level advisory lock mechanism.
//
// This function creates a SessionLocker that can be used to acquire and release locks for
// This function creates a SessionLocker that can be used to acquire and release a lock for
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the
// until the failure threshold is reached. The default lock duration is set to 5 minutes, and the
// default unlock duration is set to 1 minute.
//
// If you have long running migrations, you may want to increase the lock duration.
//
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
cfg := sessionLockerConfig{
lockID: DefaultLockID,
lockTimeout: DefaultLockTimeout,
unlockTimeout: DefaultUnlockTimeout,
lockID: DefaultLockID,
lockProbe: probe{
periodSeconds: 5 * time.Second,
failureThreshold: 60,
},
unlockProbe: probe{
periodSeconds: 2 * time.Second,
failureThreshold: 30,
},
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
@ -32,13 +40,13 @@ func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error
}
return &postgresSessionLocker{
lockID: cfg.lockID,
retryLock: retry.WithMaxDuration(
cfg.lockTimeout,
retry.NewConstant(2*time.Second),
retryLock: retry.WithMaxRetries(
cfg.lockProbe.failureThreshold,
retry.NewConstant(cfg.lockProbe.periodSeconds),
),
retryUnlock: retry.WithMaxDuration(
cfg.unlockTimeout,
retry.NewConstant(2*time.Second),
retryUnlock: retry.WithMaxRetries(
cfg.unlockProbe.failureThreshold,
retry.NewConstant(cfg.unlockProbe.periodSeconds),
),
}, nil
}

View File

@ -6,7 +6,6 @@ import (
"errors"
"sync"
"testing"
"time"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
@ -30,8 +29,8 @@ func TestPostgresSessionLocker(t *testing.T) {
)
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
lock.WithLockTimeout(1, 4), // 4 second timeout
lock.WithUnlockTimeout(1, 4), // 4 second timeout
)
check.NoError(t, err)
ctx := context.Background()
@ -60,8 +59,8 @@ func TestPostgresSessionLocker(t *testing.T) {
})
t.Run("lock_close_conn_unlock", func(t *testing.T) {
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
lock.WithLockTimeout(1, 4), // 4 second timeout
lock.WithUnlockTimeout(1, 4), // 4 second timeout
)
check.NoError(t, err)
ctx := context.Background()
@ -103,10 +102,12 @@ func TestPostgresSessionLocker(t *testing.T) {
// Exactly one connection should acquire the lock. While the other connections
// should fail to acquire the lock and timeout.
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
lock.WithLockTimeout(1, 4), // 4 second timeout
lock.WithUnlockTimeout(1, 4), // 4 second timeout
)
check.NoError(t, err)
// NOTE, we are not unlocking the lock, because we want to test that the lock is
// released when the connection is closed.
ch <- locker.SessionLock(ctx, conn)
}()
}
@ -138,8 +139,8 @@ func TestPostgresSessionLocker(t *testing.T) {
)
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
lock.WithLockTimeout(1, 4), // 4 second timeout
lock.WithUnlockTimeout(1, 4), // 4 second timeout
)
check.NoError(t, err)
@ -179,6 +180,7 @@ func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
if err != nil {
return nil, err
}
defer rows.Close()
var pgLocks []pgLock
for rows.Next() {
var p pgLock

View File

@ -1,6 +1,7 @@
package lock
import (
"errors"
"time"
)
@ -10,11 +11,6 @@ const (
//
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
DefaultLockID int64 = 5887940537704921958
// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
// wait for the lock to be released) wait durations.
DefaultLockTimeout time.Duration = 60 * time.Minute
DefaultUnlockTimeout time.Duration = 1 * time.Minute
)
// SessionLockerOption is used to configure a SessionLocker.
@ -32,26 +28,65 @@ func WithLockID(lockID int64) SessionLockerOption {
})
}
// WithLockTimeout sets the max duration to wait for the lock to be acquired.
func WithLockTimeout(duration time.Duration) SessionLockerOption {
// WithLockTimeout sets the max duration to wait for the lock to be acquired. The total duration
// will be the period times the failure threshold.
//
// By default, the lock timeout is 300s (5min), where the lock is retried every 5 seconds (period)
// up to 60 times (failure threshold).
//
// The minimum period is 1 second, and the minimum failure threshold is 1.
func WithLockTimeout(period, failureThreshold uint64) SessionLockerOption {
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
c.lockTimeout = duration
if period < 1 {
return errors.New("period must be greater than 0, minimum is 1")
}
if failureThreshold < 1 {
return errors.New("failure threshold must be greater than 0, minimum is 1")
}
c.lockProbe = probe{
periodSeconds: time.Duration(period) * time.Second,
failureThreshold: failureThreshold,
}
return nil
})
}
// WithUnlockTimeout sets the max duration to wait for the lock to be released.
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
// WithUnlockTimeout sets the max duration to wait for the lock to be released. The total duration
// will be the period times the failure threshold.
//
// By default, the lock timeout is 60s, where the lock is retried every 2 seconds (period) up to 30
// times (failure threshold).
//
// The minimum period is 1 second, and the minimum failure threshold is 1.
func WithUnlockTimeout(period, failureThreshold uint64) SessionLockerOption {
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
c.unlockTimeout = duration
if period < 1 {
return errors.New("period must be greater than 0, minimum is 1")
}
if failureThreshold < 1 {
return errors.New("failure threshold must be greater than 0, minimum is 1")
}
c.unlockProbe = probe{
periodSeconds: time.Duration(period) * time.Second,
failureThreshold: failureThreshold,
}
return nil
})
}
type sessionLockerConfig struct {
lockID int64
lockTimeout time.Duration
unlockTimeout time.Duration
lockID int64
lockProbe probe
unlockProbe probe
}
// probe is used to configure how often and how many times to retry a lock or unlock operation. The
// total timeout will be the period times the failure threshold.
type probe struct {
// How often (in seconds) to perform the probe.
periodSeconds time.Duration
// Number of times to retry the probe.
failureThreshold uint64
}
var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)

View File

@ -18,22 +18,36 @@ import (
// Both up and down functions may be nil, in which case the migration will be recorded in the
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
// a version without running any functions. See [GoFunc] for more details.
func NewGoMigration(version int64, up, down *GoFunc) Migration {
m := Migration{
func NewGoMigration(version int64, up, down *GoFunc) *Migration {
m := &Migration{
Type: TypeGo,
Registered: true,
Version: version,
Next: -1, Previous: -1,
goUp: up,
goDown: down,
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
construct: true,
}
updateMode := func(f *GoFunc) *GoFunc {
// infer mode from function
if f.Mode == 0 {
if f.RunTx != nil && f.RunDB == nil {
f.Mode = TransactionEnabled
}
if f.RunTx == nil && f.RunDB != nil {
f.Mode = TransactionDisabled
}
}
return f
}
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
// we will remove these fields in favor of [GoFunc].
//
// Note, this function does not do any validation. Validation is lazily done when the migration
// is registered.
if up != nil {
m.goUp = updateMode(up)
if up.RunDB != nil {
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
@ -45,6 +59,8 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
}
}
if down != nil {
m.goDown = updateMode(down)
if down.RunDB != nil {
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
@ -55,12 +71,6 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
}
}
if m.goUp == nil {
m.goUp = &GoFunc{Mode: TransactionEnabled}
}
if m.goDown == nil {
m.goDown = &GoFunc{Mode: TransactionEnabled}
}
return m
}
@ -76,10 +86,6 @@ type Migration struct {
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
// These fields are used internally by goose and users are not expected to set them. Instead,
// use [NewGoMigration] to create a new go migration.
construct bool
goUp, goDown *GoFunc
// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
@ -98,6 +104,26 @@ type Migration struct {
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
noVersioning bool
// These fields are used internally by goose and users are not expected to set them. Instead,
// use [NewGoMigration] to create a new go migration.
construct bool
goUp, goDown *GoFunc
sql sqlMigration
}
type sqlMigration struct {
// The Parsed field is used to track whether the SQL migration has been parsed. It serves as an
// optimization to avoid parsing migrations that may never be needed. Typically, migrations are
// incremental, and users often run only the most recent ones, making parsing of prior
// migrations unnecessary in most cases.
Parsed bool
// Parsed must be set to true before the following fields are used.
UseTx bool
Up []string
Down []string
}
// GoFunc represents a Go migration function.

View File

@ -18,3 +18,11 @@ func (osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(filepath.Fro
func (osFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(filepath.FromSlash(name)) }
func (osFS) Glob(pattern string) ([]string, error) { return filepath.Glob(filepath.FromSlash(pattern)) }
type noopFS struct{}
var _ fs.FS = noopFS{}
func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}

477
provider.go Normal file
View File

@ -0,0 +1,477 @@
package goose
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"math"
"sort"
"strconv"
"strings"
"sync"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
// Provider is a goose migration provider.
type Provider struct {
// mu protects all accesses to the provider and must be held when calling operations on the
// database.
mu sync.Mutex
db *sql.DB
store database.Store
fsys fs.FS
cfg config
// migrations are ordered by version in ascending order.
migrations []*Migration
}
// NewProvider returns a new goose provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
// example, if the database dialect is "postgres", the database/sql driver could be
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
// as using a custom table name or supplying a custom store implementation, see [WithStore].
//
// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use
// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
// migrations using [fs.Sub].
//
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
if fsys == nil {
fsys = noopFS{}
}
cfg := config{
registered: make(map[int64]*Migration),
excludePaths: make(map[string]bool),
excludeVersions: make(map[int64]bool),
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
return nil, err
}
}
// Allow users to specify a custom store implementation, but only if they don't specify a
// dialect. If they specify a dialect, we'll use the default store implementation.
if dialect == "" && cfg.store == nil {
return nil, errors.New("dialect must not be empty")
}
if dialect != "" && cfg.store != nil {
return nil, errors.New("dialect must be empty when using a custom store implementation")
}
var store database.Store
if dialect != "" {
var err error
store, err = database.NewStore(dialect, DefaultTablename)
if err != nil {
return nil, err
}
} else {
store = cfg.store
}
if store.Tablename() == "" {
return nil, errors.New("invalid store implementation: table name must not be empty")
}
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
}
func newProvider(
db *sql.DB,
store database.Store,
fsys fs.FS,
cfg config,
global map[int64]*Migration,
) (*Provider, error) {
// Collect migrations from the filesystem and merge with registered migrations.
//
// Note, we don't parse SQL migrations here. They are parsed lazily when required.
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
// we should make it optional.
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
if err != nil {
return nil, err
}
versionToGoMigration := make(map[int64]*Migration)
// Add user-registered Go migrations from the provider.
for version, m := range cfg.registered {
versionToGoMigration[version] = m
}
// Add globally registered Go migrations.
for version, m := range global {
if _, ok := versionToGoMigration[version]; ok {
return nil, fmt.Errorf("global go migration with version %d previously registered with provider", version)
}
versionToGoMigration[version] = m
}
// At this point we have all registered unique Go migrations (if any). We need to merge them
// with SQL migrations from the filesystem.
migrations, err := merge(filesystemSources, versionToGoMigration)
if err != nil {
return nil, err
}
if len(migrations) == 0 {
return nil, ErrNoMigrations
}
return &Provider{
db: db,
fsys: fsys,
cfg: cfg,
store: store,
migrations: migrations,
}, nil
}
// Status returns the status of all migrations, merging the list of migrations from the database and
// filesystem. The returned items are ordered by version, in ascending order.
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return p.status(ctx)
}
// GetDBVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
// this method returns 4. If no migrations have been applied, it returns 0.
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
return p.getDBMaxVersion(ctx)
}
// ListSources returns a list of all migration sources known to the provider, sorted in ascending
// order by version. The path field may be empty for manually registered migrations, such as Go
// migrations registered using the [WithGoMigrations] option.
func (p *Provider) ListSources() []*Source {
sources := make([]*Source, 0, len(p.migrations))
for _, m := range p.migrations {
sources = append(sources, &Source{
Type: m.Type,
Path: m.Source,
Version: m.Version,
})
}
return sources
}
// Ping attempts to ping the database to verify a connection is available.
func (p *Provider) Ping(ctx context.Context) error {
return p.db.PingContext(ctx)
}
// Close closes the database connection initially supplied to the provider.
func (p *Provider) Close() error {
return p.db.Close()
}
// ApplyVersion applies exactly one migration for the specified version. If there is no migration
// available for the specified version, this method returns [ErrVersionNotFound]. If the migration
// has already been applied, this method returns [ErrAlreadyApplied].
//
// The direction parameter determines the migration direction: true for up migration and false for
// down migration.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
res, err := p.apply(ctx, version, direction)
if err != nil {
return nil, err
}
// This should never happen, we must return exactly one result.
if len(res) != 1 {
versions := make([]string, 0, len(res))
for _, r := range res {
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
}
return nil, fmt.Errorf(
"unexpected number of migrations applied running apply, expecting exactly one result: %v",
strings.Join(versions, ","),
)
}
return res[0], nil
}
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
// empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
return p.up(ctx, false, math.MaxInt64)
}
// UpByOne applies the next pending migration. If there is no next migration to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, ErrNoNextVersion
}
// This should never happen, we must return exactly one result.
if len(res) != 1 {
versions := make([]string, 0, len(res))
for _, r := range res {
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
}
return nil, fmt.Errorf(
"unexpected number of migrations applied running up-by-one, expecting exactly one result: %v",
strings.Join(versions, ","),
)
}
return res[0], nil
}
// UpTo applies all pending migrations up to, and including, the specified version. If there are no
// migrations to apply, this method returns empty list and nil error.
//
// For example, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
return p.up(ctx, false, version)
}
// Down rolls back the most recently applied migration. If there are no migrations to rollback, this
// method returns [ErrNoNextVersion].
//
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
// the migration version. This only applies in scenarios where migrations are allowed to be applied
// out of order.
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
res, err := p.down(ctx, true, 0)
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, ErrNoNextVersion
}
// This should never happen, we must return exactly one result.
if len(res) != 1 {
versions := make([]string, 0, len(res))
for _, r := range res {
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
}
return nil, fmt.Errorf(
"unexpected number of migrations applied running down, expecting exactly one result: %v",
strings.Join(versions, ","),
)
}
return res[0], nil
}
// DownTo rolls back all migrations down to, but not including, the specified version.
//
// For example, if the current database version is 11,10,9... and the requested version is 9, only
// migrations 11, 10 will be rolled back.
//
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
// the migration version. This only applies in scenarios where migrations are allowed to be applied
// out of order.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 0 {
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
}
return p.down(ctx, false, version)
}
// *** Internal methods ***
func (p *Provider) up(
ctx context.Context,
byOne bool,
version int64,
) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errInvalidVersion
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
var apply []*Migration
if p.cfg.disableVersioning {
if byOne {
return nil, errors.New("up-by-one not supported when versioning is disabled")
}
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required
// to support the allow missing (out-of-order) feature. For users that don't use this
// feature, we could just query the database for the current max version and then apply
// migrations greater than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
apply, err = p.resolveUpMigrations(dbMigrations, version)
if err != nil {
return nil, err
}
}
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne)
}
func (p *Provider) down(
ctx context.Context,
byOne bool,
version int64,
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.disableVersioning {
var downMigrations []*Migration
if byOne {
last := p.migrations[len(p.migrations)-1]
downMigrations = []*Migration{last}
} else {
downMigrations = p.migrations
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
// We never migrate the zero version down.
if dbMigrations[0].Version == 0 {
return nil, nil
}
var apply []*Migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne)
}
func (p *Provider) apply(
ctx context.Context,
version int64,
direction bool,
) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errInvalidVersion
}
m, err := p.getMigration(version)
if err != nil {
return nil, err
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
result, err := p.store.GetMigration(ctx, conn, version)
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
return nil, err
}
// If the migration has already been applied, return an error. But, if the migration is being
// rolled back, we allow the individual migration to be applied again.
if result != nil && direction {
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
}
d := sqlparser.DirectionDown
if direction {
d = sqlparser.DirectionUp
}
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
}
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
status := make([]*MigrationStatus, 0, len(p.migrations))
for _, m := range p.migrations {
migrationStatus := &MigrationStatus{
Source: &Source{
Type: m.Type,
Path: m.Source,
Version: m.Version,
},
State: StatePending,
}
dbResult, err := p.store.GetMigration(ctx, conn, m.Version)
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
return nil, err
}
if dbResult != nil {
migrationStatus.State = StateApplied
migrationStatus.AppliedAt = dbResult.Timestamp
}
status = append(status, migrationStatus)
}
return status, nil
}
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return 0, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
res, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return 0, err
}
if len(res) == 0 {
return 0, errMissingZeroVersion
}
// Sort in descending order.
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
// Return the highest version.
return res[0].Version, nil
}

View File

@ -1,15 +1,12 @@
package provider
package goose
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"github.com/pressly/goose/v3"
)
// fileSources represents a collection of migration files on the filesystem.
@ -18,25 +15,6 @@ type fileSources struct {
goSources []Source
}
// TODO(mf): remove?
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
switch t {
case TypeGo:
for _, source := range s.goSources {
if source.Version == version {
return &source
}
}
case TypeSQL:
for _, source := range s.sqlSources {
if source.Version == version {
return &source
}
}
}
return nil
}
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
// be nil, in which case an empty fileSources is returned.
@ -46,7 +24,12 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source {
//
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
// migration sources from the filesystem.
func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
func collectFilesystemSources(
fsys fs.FS,
strict bool,
excludePaths map[string]bool,
excludeVersions map[int64]bool,
) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
}
@ -62,8 +45,11 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
}
for _, fullpath := range files {
base := filepath.Base(fullpath)
// Skip explicit excludes or Go test files.
if excludes[base] || strings.HasSuffix(base, "_test.go") {
if strings.HasSuffix(base, "_test.go") {
continue
}
if excludePaths[base] {
// TODO(mf): log this?
continue
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
@ -71,13 +57,17 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := goose.NumericComponent(base)
version, err := NumericComponent(base)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
continue
}
if excludeVersions[version] {
// TODO: log this?
continue
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
@ -101,7 +91,7 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
})
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
return nil, fmt.Errorf("invalid file extension: %q", base)
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
@ -110,15 +100,25 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
return sources, nil
}
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
var migrations []*migration
migrationLookup := make(map[int64]*migration)
func newSQLMigration(source Source) *Migration {
return &Migration{
Type: source.Type,
Version: source.Version,
Source: source.Path,
construct: true,
Next: -1, Previous: -1,
sql: sqlMigration{
Parsed: false, // SQL migrations are parsed lazily.
},
}
}
func merge(sources *fileSources, registerd map[int64]*Migration) ([]*Migration, error) {
var migrations []*Migration
migrationLookup := make(map[int64]*Migration)
// Add all SQL migrations to the list of migrations.
for _, source := range sources.sqlSources {
m := &migration{
Source: source,
SQL: nil, // SQL migrations are parsed lazily.
}
m := newSQLMigration(source)
migrations = append(migrations, m)
migrationLookup[source.Version] = m
}
@ -147,38 +147,24 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
// the SQL migration files.
for version, r := range registerd {
fullpath := r.fullpath
if fullpath == "" {
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Path
}
}
// Ensure there are no duplicate versions.
if existing, ok := migrationLookup[version]; ok {
fullpath := r.fullpath
fullpath := r.Source
if fullpath == "" {
fullpath = "manually registered (no source)"
}
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing.Source.Path,
existing.Source,
fullpath,
)
}
m := &migration{
Source: Source{
Type: TypeGo,
Path: fullpath, // May be empty if migration was registered manually.
Version: version,
},
Go: r,
}
migrations = append(migrations, m)
migrationLookup[version] = m
migrations = append(migrations, r)
migrationLookup[version] = r
}
// Sort migrations by version in ascending order.
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Source.Version < migrations[j].Source.Version
return migrations[i].Version < migrations[j].Version
})
return migrations, nil
}
@ -203,11 +189,3 @@ func unregisteredError(unregistered []string) error {
return errors.New(b.String())
}
type noopFS struct{}
var _ fs.FS = noopFS{}
func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}

View File

@ -1,4 +1,4 @@
package provider
package goose
import (
"io/fs"
@ -12,21 +12,21 @@ import (
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(nil, false, nil)
sources, err := collectFilesystemSources(nil, false, nil, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("noop_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(noopFS{}, false, nil)
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil)
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
"00000_foo.sql": sqlMapFile,
}
// strict disable - should not error
sources, err := collectFilesystemSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
// strict enabled - should error
_, err = collectFilesystemSources(mapFS, true, nil)
_, err = collectFilesystemSources(mapFS, true, nil, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration version must be greater than zero")
})
t.Run("collect", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
@ -76,6 +76,7 @@ func TestCollectFileSources(t *testing.T) {
"00002_bar.sql": true,
"00110_qux.sql": true,
},
nil,
)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
@ -96,7 +97,7 @@ func TestCollectFileSources(t *testing.T) {
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
_, err = collectFilesystemSources(fsys, true, nil)
_, err = collectFilesystemSources(fsys, true, nil, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
})
@ -108,7 +109,7 @@ func TestCollectFileSources(t *testing.T) {
"4_qux.sql": sqlMapFile,
"5_foo_test.go": {Data: []byte(`package goose_test`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
@ -123,7 +124,7 @@ func TestCollectFileSources(t *testing.T) {
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
@ -142,7 +143,7 @@ func TestCollectFileSources(t *testing.T) {
"001_foo.sql": sqlMapFile,
"01_bar.sql": sqlMapFile,
}
_, err := collectFilesystemSources(mapFS, false, nil)
_, err := collectFilesystemSources(mapFS, false, nil, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
@ -158,7 +159,7 @@ func TestCollectFileSources(t *testing.T) {
t.Helper()
f, err := fs.Sub(mapFS, dirpath)
check.NoError(t, err)
got, err := collectFilesystemSources(f, false, nil)
got, err := collectFilesystemSources(f, false, nil, nil)
check.NoError(t, err)
check.Number(t, len(got.sqlSources), len(sqlSources))
check.Number(t, len(got.goSources), 0)
@ -194,27 +195,21 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 1)
check.Equal(t, len(sources.goSources), 2)
src1 := sources.lookup(TypeSQL, 1)
check.Bool(t, src1 != nil, true)
src2 := sources.lookup(TypeGo, 2)
check.Bool(t, src2 != nil, true)
src3 := sources.lookup(TypeGo, 3)
check.Bool(t, src3 != nil, true)
t.Run("valid", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
2: newGoMigration("", nil, nil),
3: newGoMigration("", nil, nil),
})
registered := map[int64]*Migration{
2: NewGoMigration(2, nil, nil),
3: NewGoMigration(3, nil, nil),
}
migrations, err := merge(sources, registered)
check.NoError(t, err)
check.Number(t, len(migrations), 3)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
})
t.Run("unregistered_all", func(t *testing.T) {
_, err := merge(sources, nil)
@ -224,18 +219,16 @@ func TestMerge(t *testing.T) {
check.Contains(t, err.Error(), "00003_baz.go")
})
t.Run("unregistered_some", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
2: newGoMigration("", nil, nil),
})
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
check.HasError(t, err)
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
check.Contains(t, err.Error(), "00003_baz.go")
})
t.Run("duplicate_sql", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
1: newGoMigration("", nil, nil), // duplicate. SQL already exists.
2: newGoMigration("", nil, nil),
3: newGoMigration("", nil, nil),
_, err := merge(sources, map[int64]*Migration{
1: NewGoMigration(1, nil, nil), // duplicate. SQL already exists.
2: NewGoMigration(2, nil, nil),
3: NewGoMigration(3, nil, nil),
})
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
@ -250,13 +243,13 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
3: newGoMigration("", nil, nil),
migrations, err := merge(sources, map[int64]*Migration{
3: NewGoMigration(3, nil, nil),
// 4 is missing
6: newGoMigration("", nil, nil),
6: NewGoMigration(6, nil, nil),
})
check.NoError(t, err)
check.Number(t, len(migrations), 5)
@ -274,20 +267,20 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
migrations, err := merge(sources, map[int64]*Migration{
// This is the only Go file on disk.
2: newGoMigration("", nil, nil),
2: NewGoMigration(2, nil, nil),
// These are not on disk. Explicitly registered.
3: newGoMigration("", nil, nil),
6: newGoMigration("", nil, nil),
3: NewGoMigration(3, nil, nil),
6: NewGoMigration(6, nil, nil),
})
check.NoError(t, err)
check.Number(t, len(migrations), 4)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
})
@ -308,15 +301,15 @@ func TestCheckMissingMigrations(t *testing.T) {
{Version: 5},
{Version: 7}, // <-- database max version_id
}
fsMigrations := []*migration{
newMigrationVersion(1),
newMigrationVersion(2), // missing migration
newMigrationVersion(3),
newMigrationVersion(4),
newMigrationVersion(5),
newMigrationVersion(6), // missing migration
newMigrationVersion(7), // ----- database max version_id -----
newMigrationVersion(8), // new migration
fsMigrations := []*Migration{
newSQLMigration(Source{Version: 1}),
newSQLMigration(Source{Version: 2}), // missing migration
newSQLMigration(Source{Version: 3}),
newSQLMigration(Source{Version: 4}),
newSQLMigration(Source{Version: 5}),
newSQLMigration(Source{Version: 6}), // missing migration
newSQLMigration(Source{Version: 7}), // ----- database max version_id -----
newSQLMigration(Source{Version: 8}), // new migration
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
@ -334,9 +327,9 @@ func TestCheckMissingMigrations(t *testing.T) {
{Version: 5},
{Version: 2},
}
fsMigrations := []*migration{
newMigrationVersion(3), // new migration
newMigrationVersion(4), // new migration
fsMigrations := []*Migration{
NewGoMigration(3, nil, nil), // new migration
NewGoMigration(4, nil, nil), // new migration
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
@ -345,24 +338,19 @@ func TestCheckMissingMigrations(t *testing.T) {
})
}
func newMigrationVersion(version int64) *migration {
return &migration{
Source: Source{
Version: version,
},
}
}
func assertMigration(t *testing.T, got *migration, want Source) {
func assertMigration(t *testing.T, got *Migration, want Source) {
t.Helper()
check.Equal(t, got.Source, want)
switch got.Source.Type {
check.Equal(t, got.Type, want.Type)
check.Equal(t, got.Version, want.Version)
check.Equal(t, got.Source, want.Path)
switch got.Type {
case TypeGo:
check.Bool(t, got.Go != nil, true)
check.Bool(t, got.goUp != nil, true)
check.Bool(t, got.goDown != nil, true)
case TypeSQL:
check.Bool(t, got.SQL == nil, true)
check.Bool(t, got.sql.Parsed, false)
default:
t.Fatalf("unknown migration type: %s", got.Source.Type)
t.Fatalf("unknown migration type: %s", got.Type)
}
}

View File

@ -1,4 +1,4 @@
package provider
package goose
import (
"errors"
@ -16,8 +16,8 @@ var (
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
ErrNoMigrations = errors.New("no migrations found")
// ErrNoNextVersion when the next migration version is not found.
ErrNoNextVersion = errors.New("no next version found")
// errInvalidVersion is returned when a migration version is invalid.
errInvalidVersion = errors.New("version must be greater than 0")
)
// PartialError is returned when a migration fails, but some migrations already got applied.

View File

@ -1,8 +1,6 @@
package provider
package goose
import (
"context"
"database/sql"
"errors"
"fmt"
@ -12,12 +10,11 @@ import (
const (
// DefaultTablename is the default name of the database table used to track history of applied
// migrations. It can be overridden using the [WithTableName] option when creating a new
// provider.
// migrations.
DefaultTablename = "goose_db_version"
)
// ProviderOption is a configuration option for a goose provider.
// ProviderOption is a configuration option for a goose goose.
type ProviderOption interface {
apply(*config) error
}
@ -84,84 +81,75 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
})
}
// WithExcludes excludes the given file names from the list of migrations.
//
// If WithExcludes is called multiple times, the list of excludes is merged.
func WithExcludes(excludes []string) ProviderOption {
// WithExcludeNames excludes the given file name from the list of migrations. If called multiple
// times, the list of excludes is merged.
func WithExcludeNames(excludes []string) ProviderOption {
return configFunc(func(c *config) error {
for _, name := range excludes {
c.excludes[name] = true
if _, ok := c.excludePaths[name]; ok {
return fmt.Errorf("duplicate exclude file name: %s", name)
}
c.excludePaths[name] = true
}
return nil
})
}
// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration].
type GoMigrationFunc struct {
// One of the following must be set:
Run func(context.Context, *sql.Tx) error
// -- OR --
RunNoTx func(context.Context, *sql.DB) error
}
// WithGoMigration registers a Go migration with the given version.
//
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
// set.
func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption {
// WithExcludeVersions excludes the given versions from the list of migrations. If called multiple
// times, the list of excludes is merged.
func WithExcludeVersions(versions []int64) ProviderOption {
return configFunc(func(c *config) error {
if version < 1 {
return errors.New("version must be greater than zero")
}
if _, ok := c.registered[version]; ok {
return fmt.Errorf("go migration with version %d already registered", version)
}
// Allow nil up/down functions. This enables users to apply "no-op" migrations, while
// versioning them.
if up != nil {
if up.Run == nil && up.RunNoTx == nil {
return fmt.Errorf("go migration with version %d must have an up function", version)
for _, version := range versions {
if version < 1 {
return errInvalidVersion
}
if up.Run != nil && up.RunNoTx != nil {
return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version)
if _, ok := c.excludeVersions[version]; ok {
return fmt.Errorf("duplicate excludes version: %d", version)
}
}
if down != nil {
if down.Run == nil && down.RunNoTx == nil {
return fmt.Errorf("go migration with version %d must have a down function", version)
}
if down.Run != nil && down.RunNoTx != nil {
return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version)
}
}
c.registered[version] = &goMigration{
up: up,
down: down,
c.excludeVersions[version] = true
}
return nil
})
}
// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default,
// WithGoMigrations registers Go migrations with the provider. If a Go migration with the same
// version has already been registered, an error will be returned.
//
// Go migrations must be constructed using the [NewGoMigration] function.
func WithGoMigrations(migrations ...*Migration) ProviderOption {
return configFunc(func(c *config) error {
for _, m := range migrations {
if _, ok := c.registered[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
if err := checkGoMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
c.registered[m.Version] = m
}
return nil
})
}
// WithAllowOutofOrder allows the provider to apply missing (out-of-order) migrations. By default,
// goose will raise an error if it encounters a missing migration.
//
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
// by new migrations.
func WithAllowedMissing(b bool) ProviderOption {
// For example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is
// true, then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order
// of applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first,
// followed by new migrations.
func WithAllowOutofOrder(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.allowMissing = b
return nil
})
}
// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations
// WithDisableVersioning disables versioning. Disabling versioning allows applying migrations
// without tracking the versions in the database schema table. Useful for tests, seeding a database
// or running ad-hoc queries. By default, goose will track all versions in the database schema
// table.
func WithDisabledVersioning(b bool) ProviderOption {
func WithDisableVersioning(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.disableVersioning = b
return nil
@ -171,12 +159,13 @@ func WithDisabledVersioning(b bool) ProviderOption {
type config struct {
store database.Store
verbose bool
excludes map[string]bool
verbose bool
excludePaths map[string]bool
excludeVersions map[int64]bool
// Go migrations registered by the user. These will be merged/resolved with migrations from the
// filesystem and init() functions.
registered map[int64]*goMigration
// Go migrations registered by the user. These will be merged/resolved against the globally
// registered migrations.
registered map[int64]*Migration
// Locking options
lockEnabled bool

View File

@ -1,4 +1,4 @@
package provider_test
package goose_test
import (
"database/sql"
@ -6,9 +6,9 @@ import (
"testing"
"testing/fstest"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
_ "modernc.org/sqlite"
)
@ -24,45 +24,42 @@ func TestNewProvider(t *testing.T) {
}
t.Run("invalid", func(t *testing.T) {
// Empty dialect not allowed
_, err = provider.NewProvider("", db, fsys)
_, err = goose.NewProvider("", db, fsys)
check.HasError(t, err)
// Invalid dialect not allowed
_, err = provider.NewProvider("unknown-dialect", db, fsys)
_, err = goose.NewProvider("unknown-dialect", db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = provider.NewProvider(database.DialectSQLite3, nil, fsys)
check.HasError(t, err)
// Nil fsys not allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, nil)
_, err = goose.NewProvider(database.DialectSQLite3, nil, fsys)
check.HasError(t, err)
// Nil store not allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(nil))
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil))
check.HasError(t, err)
// Cannot set both dialect and store
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(store))
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(store))
check.HasError(t, err)
// Multiple stores not allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithStore(store),
provider.WithStore(store),
_, err = goose.NewProvider(database.DialectSQLite3, db, nil,
goose.WithStore(store),
goose.WithStore(store),
)
check.HasError(t, err)
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys)
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
)
check.NoError(t, err)
// Custom store allowed
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = provider.NewProvider("", db, nil, provider.WithStore(store))
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
check.HasError(t, err)
})
}

460
provider_run.go Normal file
View File

@ -0,0 +1,460 @@
package goose
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
)
var (
errMissingZeroVersion = errors.New("missing zero version migration")
)
func (p *Provider) resolveUpMigrations(
dbVersions []*database.ListMigrationsResult,
version int64,
) ([]*Migration, error) {
var apply []*Migration
var dbMaxVersion int64
// dbAppliedVersions is a map of all applied migrations in the database.
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
for _, m := range dbVersions {
dbAppliedVersions[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
// sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, fmt.Sprintf("%d", v.versionID))
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max). Skip any migrations that have already been applied.
for _, m := range p.migrations {
if dbAppliedVersions[m.Version] {
continue
}
if m.Version > dbMaxVersion && m.Version <= version {
apply = append(apply, m)
}
}
return apply, nil
}
func (p *Provider) prepareMigration(fsys fs.FS, m *Migration, direction bool) error {
switch m.Type {
case TypeGo:
if m.goUp.Mode == 0 {
return errors.New("go up migration mode is not set")
}
if m.goDown.Mode == 0 {
return errors.New("go down migration mode is not set")
}
var useTx bool
if direction {
useTx = m.goUp.Mode == TransactionEnabled
} else {
useTx = m.goDown.Mode == TransactionEnabled
}
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB,
// but are locking the database with *sql.Conn. If the caller sets max open connections to
// 1, then this will deadlock because the Go migration will try to acquire a connection from
// the pool, but the pool is exhausted because the lock is held.
//
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to
// use *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is
// a bit of an edge case. For now, we guard against this scenario by checking the max open
// connections and returning an error.
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
if !useTx {
return errors.New("potential deadlock detected: cannot run Go migration without a transaction when max open connections set to 1")
}
}
return nil
case TypeSQL:
if m.sql.Parsed {
return nil
}
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source, false)
if err != nil {
return err
}
m.sql.Parsed = true
m.sql.UseTx = parsed.UseTx
m.sql.Up, m.sql.Down = parsed.Up, parsed.Down
return nil
}
return fmt.Errorf("invalid migration type: %+v", m)
}
// runMigrations runs migrations sequentially in the given direction. If the migrations list is
// empty, return nil without error.
func (p *Provider) runMigrations(
ctx context.Context,
conn *sql.Conn,
migrations []*Migration,
direction sqlparser.Direction,
byOne bool,
) ([]*MigrationResult, error) {
if len(migrations) == 0 {
return nil, nil
}
apply := migrations
if byOne {
apply = migrations[:1]
}
// SQL migrations are lazily parsed in both directions. This is done before attempting to run
// any migrations to catch errors early and prevent leaving the database in an incomplete state.
for _, m := range apply {
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
return nil, err
}
}
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
// to run in a transaction.
// feat(mf): this is where we can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own. See the
// following issues for more details:
// - https://github.com/pressly/goose/issues/485
// - https://github.com/pressly/goose/issues/222
//
// Be careful, we can't use a single transaction for all migrations because some may be marked
// as not using a transaction.
var results []*MigrationResult
for _, m := range apply {
current := &MigrationResult{
Source: &Source{
Type: m.Type,
Path: m.Source,
Version: m.Version,
},
Direction: direction.String(),
Empty: isEmpty(m, direction.ToBool()),
}
start := time.Now()
if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil {
// TODO(mf): we should also return the pending migrations here, the remaining items in
// the apply slice.
current.Error = err
current.Duration = time.Since(start)
return nil, &PartialError{
Applied: results,
Failed: current,
Err: err,
}
}
current.Duration = time.Since(start)
results = append(results, current)
}
return results, nil
}
func (p *Provider) runIndividually(
ctx context.Context,
conn *sql.Conn,
m *Migration,
direction bool,
) error {
useTx, err := useTx(m, direction)
if err != nil {
return err
}
if useTx {
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := runMigration(ctx, tx, m, direction); err != nil {
return err
}
return p.maybeInsertOrDelete(ctx, tx, m.Version, direction)
})
}
switch m.Type {
case TypeGo:
// Note, we are using *sql.DB instead of *sql.Conn because it's the Go migration contract.
// This may be a deadlock scenario if max open connections is set to 1 AND a lock is
// acquired on the database. In this case, the migration will block forever unable to
// acquire a connection from the pool.
//
// For now, we guard against this scenario by checking the max open connections and
// returning an error in the prepareMigration function.
if err := runMigration(ctx, p.db, m, direction); err != nil {
return err
}
return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction)
case TypeSQL:
if err := runMigration(ctx, conn, m, direction); err != nil {
return err
}
return p.maybeInsertOrDelete(ctx, conn, m.Version, direction)
}
return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m)
}
func (p *Provider) maybeInsertOrDelete(
ctx context.Context,
db database.DBTxConn,
version int64,
direction bool,
) error {
// If versioning is disabled, we don't need to insert or delete the migration version.
if p.cfg.disableVersioning {
return nil
}
if direction {
return p.store.Insert(ctx, db, database.InsertRequest{Version: version})
}
return p.store.Delete(ctx, db, version)
}
// beginTx begins a transaction and runs the given function. If the function returns an error, the
// transaction is rolled back. Otherwise, the transaction is committed.
func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (retErr error) {
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if retErr != nil {
retErr = multierr.Append(retErr, tx.Rollback())
}
}()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
p.mu.Lock()
conn, err := p.db.Conn(ctx)
if err != nil {
p.mu.Unlock()
return nil, nil, err
}
// cleanup is a function that cleans up the connection, and optionally, the session lock.
cleanup := func() error {
p.mu.Unlock()
return conn.Close()
}
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
if err := l.SessionLock(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
// A lock was acquired, so we need to unlock the session when we're done. This is done by
// returning a cleanup function that unlocks the session and closes the connection.
cleanup = func() error {
p.mu.Unlock()
// Use a detached context to unlock the session. This is because the context passed to
// SessionLock may have been canceled, and we don't want to cancel the unlock.
//
// TODO(mf): use context.WithoutCancel added in go1.21
detachedCtx := context.Background()
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
}
}
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
// need the version table because no versions are being recorded.
if !p.cfg.disableVersioning {
if err := p.ensureVersionTable(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
}
return conn, cleanup, nil
}
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
if p.cfg.disableVersioning {
return nil
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
})
}
type missingMigration struct {
versionID int64
}
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
// migration is one that has a version less than the max version in the database.
func checkMissingMigrations(
dbMigrations []*database.ListMigrationsResult,
fsMigrations []*Migration,
) []missingMigration {
existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations {
existing[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
var missing []missingMigration
for _, m := range fsMigrations {
version := m.Version
if !existing[version] && version < dbMaxVersion {
missing = append(missing, missingMigration{
versionID: version,
})
}
}
sort.Slice(missing, func(i, j int) bool {
return missing[i].versionID < missing[j].versionID
})
return missing
}
// getMigration returns the migration for the given version. If no migration is found, then
// ErrVersionNotFound is returned.
func (p *Provider) getMigration(version int64) (*Migration, error) {
for _, m := range p.migrations {
if m.Version == version {
return m, nil
}
}
return nil, ErrVersionNotFound
}
// useTx is a helper function that returns true if the migration should be run in a transaction. It
// must only be called after the migration has been parsed and initialized.
func useTx(m *Migration, direction bool) (bool, error) {
switch m.Type {
case TypeGo:
if m.goUp.Mode == 0 || m.goDown.Mode == 0 {
return false, fmt.Errorf("go migrations must have a mode set")
}
if direction {
return m.goUp.Mode == TransactionEnabled, nil
}
return m.goDown.Mode == TransactionEnabled, nil
case TypeSQL:
if !m.sql.Parsed {
return false, fmt.Errorf("sql migrations must be parsed")
}
return m.sql.UseTx, nil
}
return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type)
}
// isEmpty is a helper function that returns true if the migration has no functions or no statements
// to execute. It must only be called after the migration has been parsed and initialized.
func isEmpty(m *Migration, direction bool) bool {
switch m.Type {
case TypeGo:
if direction {
return m.goUp.RunTx == nil && m.goUp.RunDB == nil
}
return m.goDown.RunTx == nil && m.goDown.RunDB == nil
case TypeSQL:
if direction {
return len(m.sql.Up) == 0
}
return len(m.sql.Down) == 0
}
return true
}
// runMigration is a helper function that runs the migration in the given direction. It must only be
// called after the migration has been parsed and initialized.
func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
switch m.Type {
case TypeGo:
return runGo(ctx, db, m, direction)
case TypeSQL:
return runSQL(ctx, db, m, direction)
}
return fmt.Errorf("invalid migration type: %q", m.Type)
}
// runGo is a helper function that runs the given Go functions in the given direction. It must only
// be called after the migration has been initialized.
func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
switch db := db.(type) {
case *sql.Conn:
return fmt.Errorf("go migrations are not supported with *sql.Conn")
case *sql.DB:
if direction && m.goUp.RunDB != nil {
return m.goUp.RunDB(ctx, db)
}
if !direction && m.goDown.RunDB != nil {
return m.goDown.RunDB(ctx, db)
}
return nil
case *sql.Tx:
if direction && m.goUp.RunTx != nil {
return m.goUp.RunTx(ctx, db)
}
if !direction && m.goDown.RunTx != nil {
return m.goDown.RunTx(ctx, db)
}
return nil
}
return fmt.Errorf("invalid database connection type: %T", db)
}
// runSQL is a helper function that runs the given SQL statements in the given direction. It must
// only be called after the migration has been parsed.
func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
if !m.sql.Parsed {
return fmt.Errorf("sql migrations must be parsed")
}
var statements []string
if direction {
statements = m.sql.Up
} else {
statements = m.sql.Down
}
for _, stmt := range statements {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}
}
return nil
}

View File

@ -1,4 +1,4 @@
package provider_test
package goose_test
import (
"context"
@ -16,9 +16,9 @@ import (
"testing"
"testing/fstest"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
"golang.org/x/sync/errgroup"
@ -45,22 +45,22 @@ func TestProviderRun(t *testing.T) {
p, _ := newProviderWithDB(t)
_, err := p.ApplyVersion(context.Background(), 999, true)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
_, err = p.ApplyVersion(context.Background(), 999, false)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
})
t.Run("run_zero", func(t *testing.T) {
p, _ := newProviderWithDB(t)
_, err := p.UpTo(context.Background(), 0)
check.HasError(t, err)
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
check.Equal(t, err.Error(), "version must be greater than 0")
_, err = p.DownTo(context.Background(), -1)
check.HasError(t, err)
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
_, err = p.ApplyVersion(context.Background(), 0, true)
check.HasError(t, err)
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
check.Equal(t, err.Error(), "version must be greater than 0")
})
t.Run("up_and_down_all", func(t *testing.T) {
ctx := context.Background()
@ -72,30 +72,30 @@ func TestProviderRun(t *testing.T) {
check.Number(t, len(sources), numCount)
// Ensure only SQL migrations are returned
for _, s := range sources {
check.Equal(t, s.Type, provider.TypeSQL)
check.Equal(t, s.Type, goose.TypeSQL)
}
// Test Up
res, err := p.Up(ctx)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false)
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "up", false)
assertResult(t, res[4], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "up", false)
assertResult(t, res[5], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "up", true)
assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
// Test Down
res, err = p.DownTo(ctx, 0)
check.NoError(t, err)
check.Number(t, len(res), numCount)
assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true)
assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false)
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "down", false)
assertResult(t, res[4], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "down", false)
assertResult(t, res[5], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "down", false)
assertResult(t, res[6], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "down", false)
})
t.Run("up_and_down_by_one", func(t *testing.T) {
ctx := context.Background()
@ -107,8 +107,8 @@ func TestProviderRun(t *testing.T) {
res, err := p.UpByOne(ctx)
counter++
if counter > maxVersion {
if !errors.Is(err, provider.ErrNoNextVersion) {
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
if !errors.Is(err, goose.ErrNoNextVersion) {
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
}
break
}
@ -126,8 +126,8 @@ func TestProviderRun(t *testing.T) {
res, err := p.Down(ctx)
counter++
if counter > maxVersion {
if !errors.Is(err, provider.ErrNoNextVersion) {
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
if !errors.Is(err, goose.ErrNoNextVersion) {
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
}
break
}
@ -149,14 +149,14 @@ func TestProviderRun(t *testing.T) {
results, err := p.UpTo(ctx, upToVersion)
check.NoError(t, err)
check.Number(t, len(results), upToVersion)
assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
// Fetch the goose version from DB
currentVersion, err := p.GetDBVersion(ctx)
check.NoError(t, err)
check.Number(t, currentVersion, upToVersion)
// Validate the version actually matches what goose claims it is
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
check.NoError(t, err)
check.Number(t, gotVersion, upToVersion)
})
@ -197,7 +197,7 @@ func TestProviderRun(t *testing.T) {
check.NoError(t, err)
check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version)
// Validate the db migration version actually matches what goose claims it is
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
check.NoError(t, err)
check.Number(t, gotVersion, currentVersion)
tables, err := getTableNames(db)
@ -213,13 +213,13 @@ func TestProviderRun(t *testing.T) {
downResult, err := p.DownTo(ctx, 0)
check.NoError(t, err)
check.Number(t, len(downResult), len(sources))
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
check.NoError(t, err)
check.Number(t, gotVersion, 0)
// Should only be left with a single table, the default goose table
tables, err := getTableNames(db)
check.NoError(t, err)
knownTables := []string{provider.DefaultTablename, "sqlite_sequence"}
knownTables := []string{goose.DefaultTablename, "sqlite_sequence"}
if !reflect.DeepEqual(tables, knownTables) {
t.Logf("got tables: %v", tables)
t.Logf("known tables: %v", knownTables)
@ -261,7 +261,7 @@ func TestProviderRun(t *testing.T) {
check.NoError(t, err)
_, err = p.ApplyVersion(ctx, 1, true)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true)
check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true)
check.Contains(t, err.Error(), "version 1: already applied")
})
t.Run("status", func(t *testing.T) {
@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) {
status, err := p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), numCount)
assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true)
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true)
assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true)
assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true)
assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true)
assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true)
assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true)
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true)
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true)
assertStatus(t, status[3], goose.StatePending, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), true)
assertStatus(t, status[4], goose.StatePending, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), true)
assertStatus(t, status[5], goose.StatePending, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), true)
assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true)
// Apply all migrations
_, err = p.Up(ctx)
check.NoError(t, err)
status, err = p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), numCount)
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false)
assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false)
assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false)
assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false)
assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false)
assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false)
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false)
assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false)
assertStatus(t, status[3], goose.StateApplied, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), false)
assertStatus(t, status[4], goose.StateApplied, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), false)
assertStatus(t, status[5], goose.StateApplied, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), false)
assertStatus(t, status[6], goose.StateApplied, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), false)
})
t.Run("tx_partial_errors", func(t *testing.T) {
countOwners := func(db *sql.DB) (int, error) {
@ -321,22 +321,22 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2');
INSERT INTO owners (owner_name) VALUES ('seed-user-3');
`),
}
p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS)
p, err := goose.NewProvider(database.DialectSQLite3, db, mapFS)
check.NoError(t, err)
_, err = p.Up(ctx)
check.HasError(t, err)
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
var expected *provider.PartialError
var expected *goose.PartialError
check.Bool(t, errors.As(err, &expected), true)
// Check Err field
check.Bool(t, expected.Err != nil, true)
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
// Check Results field
check.Number(t, len(expected.Applied), 1)
assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
// Check Failed field
check.Bool(t, expected.Failed != nil, true)
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2)
check.Bool(t, expected.Failed.Empty, false)
check.Bool(t, expected.Failed.Error != nil, true)
check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
status, err := p.Status(ctx)
check.NoError(t, err)
check.Number(t, len(status), 3)
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true)
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true)
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true)
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true)
})
}
@ -415,7 +415,7 @@ func TestConcurrentProvider(t *testing.T) {
check.NoError(t, err)
check.Number(t, currentVersion, maxVersion)
ch := make(chan []*provider.MigrationResult)
ch := make(chan []*goose.MigrationResult)
var wg sync.WaitGroup
for i := 0; i < maxVersion; i++ {
wg.Add(1)
@ -435,8 +435,8 @@ func TestConcurrentProvider(t *testing.T) {
close(ch)
}()
var (
valid [][]*provider.MigrationResult
empty [][]*provider.MigrationResult
valid [][]*goose.MigrationResult
empty [][]*goose.MigrationResult
)
for results := range ch {
if len(results) == 0 {
@ -486,9 +486,9 @@ func TestNoVersioning(t *testing.T) {
// These are owners created by migration files.
wantOwnerCount = 4
)
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
provider.WithDisabledVersioning(false), // This is the default.
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
goose.WithDisableVersioning(false), // This is the default.
)
check.Number(t, len(p.ListSources()), 3)
check.NoError(t, err)
@ -499,9 +499,9 @@ func TestNoVersioning(t *testing.T) {
check.Number(t, baseVersion, 3)
t.Run("seed-up-down-to-zero", func(t *testing.T) {
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
provider.WithVerbose(testing.Verbose()),
provider.WithDisabledVersioning(true), // Provider with no versioning.
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
goose.WithDisableVersioning(true), // Provider with no versioning.
)
check.NoError(t, err)
check.Number(t, len(p.ListSources()), 2)
@ -552,8 +552,8 @@ func TestAllowMissing(t *testing.T) {
t.Run("missing_now_allowed", func(t *testing.T) {
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
provider.WithAllowedMissing(false),
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
goose.WithAllowOutofOrder(false),
)
check.NoError(t, err)
@ -607,8 +607,8 @@ func TestAllowMissing(t *testing.T) {
t.Run("missing_allowed", func(t *testing.T) {
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
provider.WithAllowedMissing(true),
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
goose.WithAllowOutofOrder(true),
)
check.NoError(t, err)
@ -640,7 +640,7 @@ func TestAllowMissing(t *testing.T) {
check.Bool(t, upResult != nil, true)
check.Number(t, upResult.Source.Version, 6)
count, err := getGooseVersionCount(db, provider.DefaultTablename)
count, err := getGooseVersionCount(db, goose.DefaultTablename)
check.NoError(t, err)
check.Number(t, count, 6)
current, err := p.GetDBVersion(ctx)
@ -676,7 +676,7 @@ func TestAllowMissing(t *testing.T) {
testDownAndVersion(1, 1)
_, err = p.Down(ctx)
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true)
})
}
@ -691,6 +691,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
}
func TestGoOnly(t *testing.T) {
t.Cleanup(goose.ResetGlobalMigrations)
// Not parallel because each subtest modifies global state.
countUser := func(db *sql.DB) int {
@ -703,99 +704,109 @@ func TestGoOnly(t *testing.T) {
t.Run("with_tx", func(t *testing.T) {
ctx := context.Background()
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnContext: newTxFn("DROP TABLE users"),
},
register := []*goose.Migration{
goose.NewGoMigration(
1,
&goose.GoFunc{RunTx: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)")},
&goose.GoFunc{RunTx: newTxFn("DROP TABLE users")},
),
}
err := provider.SetGlobalGoMigrations(register)
err := goose.SetGlobalMigrations(register...)
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
t.Cleanup(goose.ResetGlobalMigrations)
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(
register = []*goose.Migration{
goose.NewGoMigration(
2,
&provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")},
&goose.GoFunc{RunTx: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&goose.GoFunc{RunTx: newTxFn("DELETE FROM users")},
),
}
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
goose.WithGoMigrations(register...),
)
check.NoError(t, err)
sources := p.ListSources()
check.Number(t, len(p.ListSources()), 2)
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
assertSource(t, sources[1], provider.TypeGo, "", 2)
assertSource(t, sources[0], goose.TypeGo, "", 1)
assertSource(t, sources[1], goose.TypeGo, "", 2)
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
t.Run("with_db", func(t *testing.T) {
ctx := context.Background()
register := []*provider.MigrationCopy{
{
Version: 1, Source: "00001_users_table.go", Registered: true,
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
DownFnNoTxContext: newDBFn("DROP TABLE users"),
},
register := []*goose.Migration{
goose.NewGoMigration(
1,
&goose.GoFunc{
RunDB: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
},
&goose.GoFunc{
RunDB: newDBFn("DROP TABLE users"),
},
),
}
err := provider.SetGlobalGoMigrations(register)
err := goose.SetGlobalMigrations(register...)
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
t.Cleanup(goose.ResetGlobalMigrations)
db := newDB(t)
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
provider.WithGoMigration(
register = []*goose.Migration{
goose.NewGoMigration(
2,
&provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")},
&goose.GoFunc{RunDB: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
&goose.GoFunc{RunDB: newDBFn("DELETE FROM users")},
),
}
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
goose.WithGoMigrations(register...),
)
check.NoError(t, err)
sources := p.ListSources()
check.Number(t, len(p.ListSources()), 2)
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
assertSource(t, sources[1], provider.TypeGo, "", 2)
assertSource(t, sources[0], goose.TypeGo, "", 1)
assertSource(t, sources[1], goose.TypeGo, "", 2)
// Apply migration 1
res, err := p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
check.Number(t, countUser(db), 0)
check.Bool(t, tableExists(t, db, "users"), true)
// Apply migration 2
res, err = p.UpByOne(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
check.Number(t, countUser(db), 3)
// Rollback migration 2
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
check.Number(t, countUser(db), 0)
// Rollback migration 1
res, err = p.Down(ctx)
check.NoError(t, err)
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
// Check table does not exist
check.Bool(t, tableExists(t, db, "users"), false)
})
@ -818,16 +829,23 @@ func TestLockModeAdvisorySession(t *testing.T) {
check.NoError(t, err)
t.Cleanup(cleanup)
newProvider := func() *provider.Provider {
sessionLocker, err := lock.NewPostgresSessionLocker()
check.NoError(t, err)
p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"),
provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
provider.WithVerbose(testing.Verbose()),
newProvider := func() *goose.Provider {
sessionLocker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(5, 60), // Timeout 5min. Try every 5s up to 60 times.
)
check.NoError(t, err)
p, err := goose.NewProvider(
database.DialectPostgres,
db,
os.DirFS("testdata/migrations"),
goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
)
check.NoError(t, err)
return p
}
provider1 := newProvider()
provider2 := newProvider()
@ -891,7 +909,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
for {
result, err := provider1.UpByOne(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
if errors.Is(err, goose.ErrNoNextVersion) {
return nil
}
return err
@ -907,7 +925,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
for {
result, err := provider2.UpByOne(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
if errors.Is(err, goose.ErrNoNextVersion) {
return nil
}
return err
@ -993,7 +1011,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
for {
result, err := provider1.Down(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
if errors.Is(err, goose.ErrNoNextVersion) {
return nil
}
return err
@ -1009,7 +1027,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
for {
result, err := provider2.Down(context.Background())
if err != nil {
if errors.Is(err, provider.ErrNoNextVersion) {
if errors.Is(err, goose.ErrNoNextVersion) {
return nil
}
return err
@ -1068,14 +1086,14 @@ func randomAlphaNumeric(length int) string {
return string(b)
}
func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) {
func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provider, *sql.DB) {
t.Helper()
db := newDB(t)
opts = append(
opts,
provider.WithVerbose(testing.Verbose()),
goose.WithVerbose(testing.Verbose()),
)
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
check.NoError(t, err)
return p, db
}
@ -1118,14 +1136,14 @@ func getTableNames(db *sql.DB) ([]string, error) {
return tables, nil
}
func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) {
func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) {
t.Helper()
check.Equal(t, got.State, state)
check.Equal(t, got.Source, source)
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
}
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) {
t.Helper()
check.Bool(t, got != nil, true)
check.Equal(t, got.Source, source)
@ -1135,21 +1153,15 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
check.Bool(t, got.Duration > 0, true)
}
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) {
t.Helper()
check.Equal(t, got.Type, typ)
check.Equal(t, got.Path, name)
check.Equal(t, got.Version, version)
switch got.Type {
case provider.TypeGo:
check.Equal(t, got.Type.String(), "go")
case provider.TypeSQL:
check.Equal(t, got.Type.String(), "sql")
}
}
func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source {
return provider.Source{
func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source {
return &goose.Source{
Type: t,
Path: fullpath,
Version: version,

79
provider_test.go Normal file
View File

@ -0,0 +1,79 @@
package goose_test
import (
"database/sql"
"errors"
"io/fs"
"path/filepath"
"testing"
"testing/fstest"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
_ "modernc.org/sqlite"
)
func TestProvider(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := goose.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true)
})
mapFS := fstest.MapFS{
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1))
check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2))
}
var (
migration1 = `
-- +goose Up
CREATE TABLE foo (id INTEGER PRIMARY KEY);
-- +goose Down
DROP TABLE foo;
`
migration2 = `
-- +goose Up
ALTER TABLE foo ADD COLUMN name TEXT;
-- +goose Down
ALTER TABLE foo DROP COLUMN name;
`
migration3 = `
-- +goose Up
CREATE TABLE bar (
id INTEGER PRIMARY KEY,
description TEXT
);
-- +goose Down
DROP TABLE bar;
`
migration4 = `
-- +goose Up
-- Rename the 'foo' table to 'my_foo'
ALTER TABLE foo RENAME TO my_foo;
-- Add a new column 'timestamp' to 'my_foo'
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
-- +goose Down
-- Remove the 'timestamp' column from 'my_foo'
ALTER TABLE my_foo DROP COLUMN timestamp;
-- Rename the 'my_foo' table back to 'foo'
ALTER TABLE my_foo RENAME TO foo;
`
)

View File

@ -1,30 +1,15 @@
package provider
package goose
import (
"fmt"
"time"
)
import "time"
// MigrationType is the type of migration.
type MigrationType int
type MigrationType string
const (
TypeGo MigrationType = iota + 1
TypeSQL
TypeGo MigrationType = "go"
TypeSQL MigrationType = "sql"
)
func (t MigrationType) String() string {
switch t {
case TypeGo:
return "go"
case TypeSQL:
return "sql"
default:
// This should never happen.
return fmt.Sprintf("unknown (%d)", t)
}
}
// Source represents a single migration source.
//
// The Path field may be empty if the migration was registered manually. This is typically the case
@ -37,7 +22,7 @@ type Source struct {
// MigrationResult is the result of a single migration operation.
type MigrationResult struct {
Source Source
Source *Source
Duration time.Duration
Direction string
// Empty indicates no action was taken during the migration, but it was still versioned. For
@ -64,7 +49,7 @@ const (
// MigrationStatus represents the status of a single migration.
type MigrationStatus struct {
Source Source
Source *Source
State State
AppliedAt time.Time
}

View File

@ -66,7 +66,7 @@ func register(filename string, useTx bool, up, down *GoFunc) error {
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
// we know based on the register function what the user is requesting.
m.UseTx = useTx
registeredGoMigrations[v] = &m
registeredGoMigrations[v] = m
return nil
}

View File

@ -1,17 +0,0 @@
package goose
// MigrationType is the type of migration.
type MigrationType string
const (
TypeGo MigrationType = "go"
TypeSQL MigrationType = "sql"
)
func (t MigrationType) String() string {
// This should never happen.
if t == "" {
return "unknown migration type"
}
return string(t)
}