mirror of https://github.com/pressly/goose.git
feat: Add goose provider (#635)
parent
8503d4e20b
commit
04e12b88f4
|
@ -2,6 +2,7 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -100,6 +101,9 @@ func (s *store) GetMigration(
|
|||
&result.Timestamp,
|
||||
&result.IsApplied,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
|
||||
}
|
||||
return &result, nil
|
||||
|
|
|
@ -2,9 +2,15 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
|
||||
ErrVersionNotFound = errors.New("version not found")
|
||||
)
|
||||
|
||||
// Store is an interface that defines methods for managing database migrations and versioning. By
|
||||
// defining a Store interface, we can support multiple databases with consistent functionality.
|
||||
//
|
||||
|
@ -24,8 +30,8 @@ type Store interface {
|
|||
// Delete deletes a version id from the version table.
|
||||
Delete(ctx context.Context, db DBTxConn, version int64) error
|
||||
|
||||
// GetMigration retrieves a single migration by version id. This method may return the raw sql
|
||||
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
|
||||
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
|
||||
// version is not found, this method must return [ErrVersionNotFound].
|
||||
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
|
||||
|
||||
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
|
||||
|
|
|
@ -205,7 +205,7 @@ func testStore(
|
|||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||
_, err := store.GetMigration(ctx, conn, 0)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
|
||||
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
|
||||
return nil
|
||||
})
|
||||
check.NoError(t, err)
|
||||
|
|
60
globals.go
60
globals.go
|
@ -22,13 +22,12 @@ func ResetGlobalMigrations() {
|
|||
// [NewGoMigration] function.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func SetGlobalMigrations(migrations ...Migration) error {
|
||||
for _, migration := range migrations {
|
||||
m := &migration
|
||||
func SetGlobalMigrations(migrations ...*Migration) error {
|
||||
for _, m := range migrations {
|
||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||
}
|
||||
if err := checkMigration(m); err != nil {
|
||||
if err := checkGoMigration(m); err != nil {
|
||||
return fmt.Errorf("invalid go migration: %w", err)
|
||||
}
|
||||
registeredGoMigrations[m.Version] = m
|
||||
|
@ -36,7 +35,7 @@ func SetGlobalMigrations(migrations ...Migration) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkMigration(m *Migration) error {
|
||||
func checkGoMigration(m *Migration) error {
|
||||
if !m.construct {
|
||||
return errors.New("must use NewGoMigration to construct migrations")
|
||||
}
|
||||
|
@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
|
|||
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
|
||||
}
|
||||
}
|
||||
if err := setGoFunc(m.goUp); err != nil {
|
||||
if err := checkGoFunc(m.goUp); err != nil {
|
||||
return fmt.Errorf("up function: %w", err)
|
||||
}
|
||||
if err := setGoFunc(m.goDown); err != nil {
|
||||
if err := checkGoFunc(m.goDown); err != nil {
|
||||
return fmt.Errorf("down function: %w", err)
|
||||
}
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
|
@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func setGoFunc(f *GoFunc) error {
|
||||
if f == nil {
|
||||
f = &GoFunc{Mode: TransactionEnabled}
|
||||
return nil
|
||||
}
|
||||
func checkGoFunc(f *GoFunc) error {
|
||||
if f.RunTx != nil && f.RunDB != nil {
|
||||
return errors.New("must specify exactly one of RunTx or RunDB")
|
||||
}
|
||||
if f.RunTx == nil && f.RunDB == nil {
|
||||
switch f.Mode {
|
||||
case 0:
|
||||
// Default to TransactionEnabled ONLY if mode is not set explicitly.
|
||||
f.Mode = TransactionEnabled
|
||||
case TransactionEnabled, TransactionDisabled:
|
||||
// No functions but mode is set. This is not an error. It means the user wants to record
|
||||
// a version with the given mode but not run any functions.
|
||||
default:
|
||||
return fmt.Errorf("invalid mode: %d", f.Mode)
|
||||
}
|
||||
return nil
|
||||
switch f.Mode {
|
||||
case TransactionEnabled, TransactionDisabled:
|
||||
// No functions, but mode is set. This is not an error. It means the user wants to
|
||||
// record a version with the given mode but not run any functions.
|
||||
default:
|
||||
return fmt.Errorf("invalid mode: %d", f.Mode)
|
||||
}
|
||||
if f.RunDB != nil {
|
||||
switch f.Mode {
|
||||
case 0, TransactionDisabled:
|
||||
f.Mode = TransactionDisabled
|
||||
default:
|
||||
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
|
||||
}
|
||||
if f.RunDB != nil && f.Mode != TransactionDisabled {
|
||||
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
|
||||
}
|
||||
if f.RunTx != nil {
|
||||
switch f.Mode {
|
||||
case 0, TransactionEnabled:
|
||||
f.Mode = TransactionEnabled
|
||||
default:
|
||||
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
|
||||
}
|
||||
}
|
||||
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
|
||||
// the functions or return an error. This should never happen.
|
||||
if f.Mode == 0 {
|
||||
return errors.New("failed to infer transaction mode")
|
||||
if f.RunTx != nil && f.Mode != TransactionEnabled {
|
||||
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
|
|||
// reset so we can check the default is set
|
||||
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
||||
err = SetGlobalMigrations(migration2)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(registeredGoMigrations), 2)
|
||||
registered = registeredGoMigrations[2]
|
||||
check.Bool(t, registered.goUp != nil, true)
|
||||
check.Bool(t, registered.goDown != nil, true)
|
||||
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
||||
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
|
||||
|
||||
migration3 := NewGoMigration(3, nil, nil)
|
||||
// reset so we can check the default is set
|
||||
migration3.goDown.Mode = 0
|
||||
err = SetGlobalMigrations(migration3)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
|
||||
})
|
||||
t.Run("unknown_mode", func(t *testing.T) {
|
||||
m := NewGoMigration(1, nil, nil)
|
||||
|
@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
|
|||
runTx := func(context.Context, *sql.Tx) error { return nil }
|
||||
|
||||
// Success.
|
||||
err := SetGlobalMigrations([]Migration{}...)
|
||||
err := SetGlobalMigrations([]*Migration{}...)
|
||||
check.NoError(t, err)
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||
|
@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
|
|||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "go migration with version 1 already registered")
|
||||
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
|
||||
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||
}
|
||||
|
||||
func TestCheckMigration(t *testing.T) {
|
||||
// Failures.
|
||||
err := checkMigration(&Migration{})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||
err = checkMigration(&Migration{construct: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must be registered")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `type must be "go"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "version must be greater than zero")
|
||||
// Success.
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
|
||||
err := checkGoMigration(NewGoMigration(1, nil, nil))
|
||||
check.NoError(t, err)
|
||||
// Failures.
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
||||
err = checkGoMigration(&Migration{})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||
err = checkGoMigration(&Migration{construct: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must be registered")
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `type must be "go"`)
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "version must be greater than zero")
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "up function: invalid mode: 0")
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "down function: invalid mode: 0")
|
||||
// Success.
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
|
||||
check.NoError(t, err)
|
||||
// Failures.
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `no filename separator '_' found`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
UpFn: func(*sql.Tx) error { return nil },
|
||||
UpFnNoTx: func(*sql.DB) error { return nil },
|
||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
DownFn: func(*sql.Tx) error { return nil },
|
||||
DownFnNoTx: func(*sql.DB) error { return nil },
|
||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
||||
|
|
|
@ -1,186 +0,0 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
)
|
||||
|
||||
type migration struct {
|
||||
Source Source
|
||||
// A migration is either a Go migration or a SQL migration, but never both.
|
||||
//
|
||||
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
|
||||
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
|
||||
// majority of the time migrations are incremental, so it is likely that the user will only want
|
||||
// to run the last few migrations, and there is no need to parse ALL prior migrations.
|
||||
//
|
||||
// Exactly one of these fields will be set:
|
||||
Go *goMigration
|
||||
// -- OR --
|
||||
SQL *sqlMigration
|
||||
}
|
||||
|
||||
func (m *migration) useTx(direction bool) bool {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
return m.SQL.UseTx
|
||||
case TypeGo:
|
||||
if m.Go == nil || m.Go.isEmpty(direction) {
|
||||
return false
|
||||
}
|
||||
if direction {
|
||||
return m.Go.up.Run != nil
|
||||
}
|
||||
return m.Go.down.Run != nil
|
||||
}
|
||||
// This should never happen.
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *migration) isEmpty(direction bool) bool {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
return m.SQL == nil || m.SQL.isEmpty(direction)
|
||||
case TypeGo:
|
||||
return m.Go == nil || m.Go.isEmpty(direction)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *migration) filename() string {
|
||||
return filepath.Base(m.Source.Path)
|
||||
}
|
||||
|
||||
// run runs the migration inside of a transaction.
|
||||
func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("tx: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, tx, direction)
|
||||
case TypeGo:
|
||||
return m.Go.run(ctx, tx, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
// runNoTx runs the migration without a transaction.
|
||||
func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("db: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, db, direction)
|
||||
case TypeGo:
|
||||
return m.Go.runNoTx(ctx, db, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
// runConn runs the migration without a transaction using the provided connection.
|
||||
func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error {
|
||||
switch m.Source.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil {
|
||||
return fmt.Errorf("conn: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, conn, direction)
|
||||
case TypeGo:
|
||||
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
||||
}
|
||||
|
||||
type goMigration struct {
|
||||
fullpath string
|
||||
up, down *GoMigrationFunc
|
||||
}
|
||||
|
||||
func (g *goMigration) isEmpty(direction bool) bool {
|
||||
if g.up == nil && g.down == nil {
|
||||
panic("go migration has no up or down")
|
||||
}
|
||||
if direction {
|
||||
return g.up == nil
|
||||
}
|
||||
return g.down == nil
|
||||
}
|
||||
|
||||
func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration {
|
||||
return &goMigration{
|
||||
fullpath: fullpath,
|
||||
up: up,
|
||||
down: down,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
var fn func(context.Context, *sql.Tx) error
|
||||
if direction && g.up != nil {
|
||||
fn = g.up.Run
|
||||
}
|
||||
if !direction && g.down != nil {
|
||||
fn = g.down.Run
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
var fn func(context.Context, *sql.DB) error
|
||||
if direction && g.up != nil {
|
||||
fn = g.up.RunNoTx
|
||||
}
|
||||
if !direction && g.down != nil {
|
||||
fn = g.down.RunNoTx
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type sqlMigration struct {
|
||||
UseTx bool
|
||||
UpStatements []string
|
||||
DownStatements []string
|
||||
}
|
||||
|
||||
func (s *sqlMigration) isEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return len(s.UpStatements) == 0
|
||||
}
|
||||
return len(s.DownStatements) == 0
|
||||
}
|
||||
|
||||
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
|
||||
var statements []string
|
||||
if direction {
|
||||
statements = s.UpStatements
|
||||
} else {
|
||||
statements = s.DownStatements
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,68 +0,0 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
type MigrationCopy struct {
|
||||
Version int64
|
||||
Source string // path to .sql script or go file
|
||||
Registered bool
|
||||
UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
|
||||
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
var registeredGoMigrations = make(map[int64]*MigrationCopy)
|
||||
|
||||
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
|
||||
// the migrations are nil or if a migration with the same version has already been registered.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func SetGlobalGoMigrations(migrations []*MigrationCopy) error {
|
||||
for _, m := range migrations {
|
||||
if m == nil {
|
||||
return errors.New("cannot register nil go migration")
|
||||
}
|
||||
if m.Version < 1 {
|
||||
return errors.New("migration versions must be greater than zero")
|
||||
}
|
||||
if !m.Registered {
|
||||
return errors.New("migration must be registered")
|
||||
}
|
||||
if m.Source != "" {
|
||||
// If the source is set, expect it to be a file path with a numeric component that
|
||||
// matches the version.
|
||||
version, err := goose.NumericComponent(m.Source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if version != m.Version {
|
||||
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
|
||||
}
|
||||
}
|
||||
// It's valid for all of these to be nil.
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
}
|
||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
}
|
||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||
}
|
||||
registeredGoMigrations[m.Version] = m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetGlobalGoMigrations resets the global go migrations registry.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func ResetGlobalGoMigrations() {
|
||||
registeredGoMigrations = make(map[int64]*MigrationCopy)
|
||||
}
|
|
@ -1,272 +0,0 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
)
|
||||
|
||||
// Provider is a goose migration provider.
|
||||
type Provider struct {
|
||||
// mu protects all accesses to the provider and must be held when calling operations on the
|
||||
// database.
|
||||
mu sync.Mutex
|
||||
|
||||
db *sql.DB
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
store database.Store
|
||||
|
||||
// migrations are ordered by version in ascending order.
|
||||
migrations []*migration
|
||||
}
|
||||
|
||||
// NewProvider returns a new goose Provider.
|
||||
//
|
||||
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
||||
// example, if the database dialect is "postgres", the database/sql driver could be
|
||||
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
|
||||
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
|
||||
// as using a custom table name or supplying a custom store implementation, see [WithStore].
|
||||
//
|
||||
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
|
||||
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
|
||||
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
|
||||
// migrations using [fs.Sub].
|
||||
//
|
||||
// See [ProviderOption] for more information on configuring the provider.
|
||||
//
|
||||
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
|
||||
//
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("db must not be nil")
|
||||
}
|
||||
if fsys == nil {
|
||||
fsys = noopFS{}
|
||||
}
|
||||
cfg := config{
|
||||
registered: make(map[int64]*goMigration),
|
||||
excludes: make(map[string]bool),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Allow users to specify a custom store implementation, but only if they don't specify a
|
||||
// dialect. If they specify a dialect, we'll use the default store implementation.
|
||||
if dialect == "" && cfg.store == nil {
|
||||
return nil, errors.New("dialect must not be empty")
|
||||
}
|
||||
if dialect != "" && cfg.store != nil {
|
||||
return nil, errors.New("cannot set both dialect and custom store")
|
||||
}
|
||||
var store database.Store
|
||||
if dialect != "" {
|
||||
var err error
|
||||
store, err = database.NewStore(dialect, DefaultTablename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
store = cfg.store
|
||||
}
|
||||
if store.Tablename() == "" {
|
||||
return nil, errors.New("invalid store implementation: table name must not be empty")
|
||||
}
|
||||
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
|
||||
}
|
||||
|
||||
func newProvider(
|
||||
db *sql.DB,
|
||||
store database.Store,
|
||||
fsys fs.FS,
|
||||
cfg config,
|
||||
global map[int64]*MigrationCopy,
|
||||
) (*Provider, error) {
|
||||
// Collect migrations from the filesystem and merge with registered migrations.
|
||||
//
|
||||
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
|
||||
// lazily.
|
||||
//
|
||||
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
|
||||
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
|
||||
// though, so we should make it optional.
|
||||
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registered := make(map[int64]*goMigration)
|
||||
// Add user-registered Go migrations.
|
||||
for version, m := range cfg.registered {
|
||||
registered[version] = newGoMigration("", m.up, m.down)
|
||||
}
|
||||
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
|
||||
// struct to the new goMigration struct.
|
||||
for version, m := range global {
|
||||
if _, ok := registered[version]; ok {
|
||||
return nil, fmt.Errorf("go migration with version %d already registered", version)
|
||||
}
|
||||
if m == nil {
|
||||
return nil, errors.New("registered migration with nil init function")
|
||||
}
|
||||
g := newGoMigration(m.Source, nil, nil)
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext")
|
||||
}
|
||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
||||
return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext")
|
||||
}
|
||||
// Up
|
||||
if m.UpFnContext != nil {
|
||||
g.up = &GoMigrationFunc{
|
||||
Run: m.UpFnContext,
|
||||
}
|
||||
} else if m.UpFnNoTxContext != nil {
|
||||
g.up = &GoMigrationFunc{
|
||||
RunNoTx: m.UpFnNoTxContext,
|
||||
}
|
||||
}
|
||||
// Down
|
||||
if m.DownFnContext != nil {
|
||||
g.down = &GoMigrationFunc{
|
||||
Run: m.DownFnContext,
|
||||
}
|
||||
} else if m.DownFnNoTxContext != nil {
|
||||
g.down = &GoMigrationFunc{
|
||||
RunNoTx: m.DownFnNoTxContext,
|
||||
}
|
||||
}
|
||||
registered[version] = g
|
||||
}
|
||||
migrations, err := merge(filesystemSources, registered)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(migrations) == 0 {
|
||||
return nil, ErrNoMigrations
|
||||
}
|
||||
return &Provider{
|
||||
db: db,
|
||||
fsys: fsys,
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
migrations: migrations,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the status of all migrations, merging the list of migrations from the database and
|
||||
// filesystem. The returned items are ordered by version, in ascending order.
|
||||
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
||||
return p.status(ctx)
|
||||
}
|
||||
|
||||
// GetDBVersion returns the max version from the database, regardless of the applied order. For
|
||||
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
|
||||
// applied, it returns 0.
|
||||
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
||||
return p.getDBVersion(ctx)
|
||||
}
|
||||
|
||||
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
|
||||
// ascending order by version.
|
||||
func (p *Provider) ListSources() []Source {
|
||||
sources := make([]Source, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
sources = append(sources, m.Source)
|
||||
}
|
||||
return sources
|
||||
}
|
||||
|
||||
// Ping attempts to ping the database to verify a connection is available.
|
||||
func (p *Provider) Ping(ctx context.Context) error {
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (p *Provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// ApplyVersion applies exactly one migration by version. If there is no source for the specified
|
||||
// version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
|
||||
// this method returns [ErrAlreadyApplied].
|
||||
//
|
||||
// When direction is true, the up migration is executed, and when direction is false, the down
|
||||
// migration is executed.
|
||||
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
||||
if version < 1 {
|
||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
||||
}
|
||||
return p.apply(ctx, version, direction)
|
||||
}
|
||||
|
||||
// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
|
||||
// returns empty list and nil error.
|
||||
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
||||
return p.up(ctx, false, math.MaxInt64)
|
||||
}
|
||||
|
||||
// UpByOne applies the next available migration. If there are no migrations to apply, this method
|
||||
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
|
||||
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.up(ctx, true, math.MaxInt64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
// This should never happen. We should always have exactly one result and test for this.
|
||||
if len(res) > 1 {
|
||||
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// UpTo applies all available migrations up to, and including, the specified version. If there are
|
||||
// no migrations to apply, this method returns empty list and nil error.
|
||||
//
|
||||
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
|
||||
// with a requested version of 10, only versions 9,10 will be applied.
|
||||
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
if version < 1 {
|
||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
||||
}
|
||||
return p.up(ctx, false, version)
|
||||
}
|
||||
|
||||
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
|
||||
// method returns [ErrNoNextVersion].
|
||||
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.down(ctx, true, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
if len(res) > 1 {
|
||||
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// DownTo rolls back all migrations down to, but not including, the specified version.
|
||||
//
|
||||
// For instance, if the current database version is 11,10,9... and the requested version is 9, only
|
||||
// migrations 11, 10 will be rolled back.
|
||||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
if version < 0 {
|
||||
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
|
||||
}
|
||||
return p.down(ctx, false, version)
|
||||
}
|
|
@ -1,153 +0,0 @@
|
|||
package provider_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/provider"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||
check.NoError(t, err)
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
_, err := provider.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
|
||||
})
|
||||
|
||||
mapFS := fstest.MapFS{
|
||||
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
|
||||
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys)
|
||||
check.NoError(t, err)
|
||||
sources := p.ListSources()
|
||||
check.Equal(t, len(sources), 2)
|
||||
check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1))
|
||||
check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2))
|
||||
|
||||
t.Run("duplicate_go", func(t *testing.T) {
|
||||
// Not parallel because it modifies global state.
|
||||
register := []*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: nil,
|
||||
DownFnContext: nil,
|
||||
},
|
||||
}
|
||||
err := provider.SetGlobalGoMigrations(register)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
|
||||
db := newDB(t)
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
|
||||
provider.WithGoMigration(1, nil, nil),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "go migration with version 1 already registered")
|
||||
})
|
||||
t.Run("empty_go", func(t *testing.T) {
|
||||
db := newDB(t)
|
||||
// explicit
|
||||
_, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
||||
provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "go migration with version 1 must have an up function")
|
||||
})
|
||||
t.Run("duplicate_up", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
})
|
||||
t.Run("duplicate_down", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
})
|
||||
t.Run("not_registered", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go",
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration must be registered")
|
||||
})
|
||||
t.Run("zero_not_allowed", func(t *testing.T) {
|
||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
||||
{
|
||||
Version: 0,
|
||||
},
|
||||
})
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration versions must be greater than zero")
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
migration1 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
||||
-- +goose Down
|
||||
DROP TABLE foo;
|
||||
`
|
||||
migration2 = `
|
||||
-- +goose Up
|
||||
ALTER TABLE foo ADD COLUMN name TEXT;
|
||||
-- +goose Down
|
||||
ALTER TABLE foo DROP COLUMN name;
|
||||
`
|
||||
migration3 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE bar (
|
||||
id INTEGER PRIMARY KEY,
|
||||
description TEXT
|
||||
);
|
||||
-- +goose Down
|
||||
DROP TABLE bar;
|
||||
`
|
||||
migration4 = `
|
||||
-- +goose Up
|
||||
-- Rename the 'foo' table to 'my_foo'
|
||||
ALTER TABLE foo RENAME TO my_foo;
|
||||
|
||||
-- Add a new column 'timestamp' to 'my_foo'
|
||||
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
||||
|
||||
-- +goose Down
|
||||
-- Remove the 'timestamp' column from 'my_foo'
|
||||
ALTER TABLE my_foo DROP COLUMN timestamp;
|
||||
|
||||
-- Rename the 'my_foo' table back to 'foo'
|
||||
ALTER TABLE my_foo RENAME TO foo;
|
||||
`
|
||||
)
|
|
@ -1,516 +0,0 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingZeroVersion = errors.New("missing zero version migration")
|
||||
)
|
||||
|
||||
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errors.New("version must be greater than zero")
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var apply []*migration
|
||||
if p.cfg.disableVersioning {
|
||||
apply = p.migrations
|
||||
} else {
|
||||
// optimize(mf): Listing all migrations from the database isn't great. This is only required
|
||||
// to support the allow missing (out-of-order) feature. For users that don't use this
|
||||
// feature, we could just query the database for the current max version and then apply
|
||||
// migrations greater than that version.
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
apply, err = p.resolveUpMigrations(dbMigrations, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
|
||||
// transaction. The default is to apply each migration sequentially on its own.
|
||||
// https://github.com/pressly/goose/issues/222
|
||||
//
|
||||
// Careful, we can't use a single transaction for all migrations because some may have to be run
|
||||
// in their own transaction.
|
||||
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
|
||||
}
|
||||
|
||||
func (p *Provider) resolveUpMigrations(
|
||||
dbVersions []*database.ListMigrationsResult,
|
||||
version int64,
|
||||
) ([]*migration, error) {
|
||||
var apply []*migration
|
||||
var dbMaxVersion int64
|
||||
// dbAppliedVersions is a map of all applied migrations in the database.
|
||||
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
|
||||
for _, m := range dbVersions {
|
||||
dbAppliedVersions[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
|
||||
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
|
||||
// migrations entirely. At the moment this is not supported, but leaving this comment because
|
||||
// that's where that logic would be handled.
|
||||
//
|
||||
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
|
||||
// sure if this is a common use case, but it's possible.
|
||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||
var collected []string
|
||||
for _, v := range missingMigrations {
|
||||
collected = append(collected, v.filename)
|
||||
}
|
||||
msg := "migration"
|
||||
if len(collected) > 1 {
|
||||
msg += "s"
|
||||
}
|
||||
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
|
||||
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
||||
)
|
||||
}
|
||||
for _, v := range missingMigrations {
|
||||
m, err := p.getMigration(v.versionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apply = append(apply, m)
|
||||
}
|
||||
// filter all migrations with a version greater than the supplied version (min) and less than or
|
||||
// equal to the requested version (max). Skip any migrations that have already been applied.
|
||||
for _, m := range p.migrations {
|
||||
if dbAppliedVersions[m.Source.Version] {
|
||||
continue
|
||||
}
|
||||
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
|
||||
apply = append(apply, m)
|
||||
}
|
||||
}
|
||||
return apply, nil
|
||||
}
|
||||
|
||||
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
downMigrations := p.migrations
|
||||
if downByOne {
|
||||
last := p.migrations[len(p.migrations)-1]
|
||||
downMigrations = []*migration{last}
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
if dbMigrations[0].Version == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var downMigrations []*migration
|
||||
for _, dbMigration := range dbMigrations {
|
||||
if dbMigration.Version <= version {
|
||||
break
|
||||
}
|
||||
m, err := p.getMigration(dbMigration.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
downMigrations = append(downMigrations, m)
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
||||
}
|
||||
|
||||
// runMigrations runs migrations sequentially in the given direction.
|
||||
//
|
||||
// If the migrations list is empty, return nil without error.
|
||||
func (p *Provider) runMigrations(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
migrations []*migration,
|
||||
direction sqlparser.Direction,
|
||||
byOne bool,
|
||||
) ([]*MigrationResult, error) {
|
||||
if len(migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
apply := migrations
|
||||
if byOne {
|
||||
apply = migrations[:1]
|
||||
}
|
||||
// Lazily parse SQL migrations (if any) in both directions. We do this before running any
|
||||
// migrations so that we can fail fast if there are any errors and avoid leaving the database in
|
||||
// a partially migrated state.
|
||||
if err := parseSQL(p.fsys, false, apply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
|
||||
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
|
||||
// to run in a transaction.
|
||||
|
||||
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
|
||||
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
|
||||
// this will deadlock because the Go migration will try to acquire a connection from the pool,
|
||||
// but the pool is locked.
|
||||
//
|
||||
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
|
||||
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
|
||||
// an edge case.
|
||||
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
|
||||
for _, m := range apply {
|
||||
switch m.Source.Type {
|
||||
case TypeGo:
|
||||
if m.Go != nil && m.useTx(direction.ToBool()) {
|
||||
return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid allocating a slice because we may have a partial migration error.
|
||||
// 1. Avoid giving the impression that N migrations were applied when in fact some were not
|
||||
// 2. Avoid the caller having to check for nil results
|
||||
var results []*MigrationResult
|
||||
for _, m := range apply {
|
||||
current := &MigrationResult{
|
||||
Source: m.Source,
|
||||
Direction: direction.String(),
|
||||
Empty: m.isEmpty(direction.ToBool()),
|
||||
}
|
||||
start := time.Now()
|
||||
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
|
||||
// TODO(mf): we should also return the pending migrations here, the remaining items in
|
||||
// the apply slice.
|
||||
current.Error = err
|
||||
current.Duration = time.Since(start)
|
||||
return nil, &PartialError{
|
||||
Applied: results,
|
||||
Failed: current,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
current.Duration = time.Since(start)
|
||||
results = append(results, current)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (p *Provider) runIndividually(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
direction bool,
|
||||
m *migration,
|
||||
) error {
|
||||
if m.useTx(direction) {
|
||||
// Run the migration in a transaction.
|
||||
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := m.run(ctx, tx, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
return nil
|
||||
}
|
||||
if direction {
|
||||
return p.store.Insert(ctx, tx, database.InsertRequest{Version: m.Source.Version})
|
||||
}
|
||||
return p.store.Delete(ctx, tx, m.Source.Version)
|
||||
})
|
||||
}
|
||||
// Run the migration outside of a transaction.
|
||||
switch m.Source.Type {
|
||||
case TypeGo:
|
||||
// Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the
|
||||
// GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open
|
||||
// connections to 1. See the comment in runMigrations for more details.
|
||||
if err := m.runNoTx(ctx, p.db, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
case TypeSQL:
|
||||
if err := m.runConn(ctx, conn, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
return nil
|
||||
}
|
||||
if direction {
|
||||
return p.store.Insert(ctx, conn, database.InsertRequest{Version: m.Source.Version})
|
||||
}
|
||||
return p.store.Delete(ctx, conn, m.Source.Version)
|
||||
}
|
||||
|
||||
// beginTx begins a transaction and runs the given function. If the function returns an error, the
|
||||
// transaction is rolled back. Otherwise, the transaction is committed.
|
||||
//
|
||||
// If the provider is configured to use versioning, this function also inserts or deletes the
|
||||
// migration version.
|
||||
func (p *Provider) beginTx(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
fn func(tx *sql.Tx) error,
|
||||
) (retErr error) {
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
retErr = multierr.Append(retErr, tx.Rollback())
|
||||
}
|
||||
}()
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
|
||||
p.mu.Lock()
|
||||
conn, err := p.db.Conn(ctx)
|
||||
if err != nil {
|
||||
p.mu.Unlock()
|
||||
return nil, nil, err
|
||||
}
|
||||
// cleanup is a function that cleans up the connection, and optionally, the session lock.
|
||||
cleanup := func() error {
|
||||
p.mu.Unlock()
|
||||
return conn.Close()
|
||||
}
|
||||
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
|
||||
if err := l.SessionLock(ctx, conn); err != nil {
|
||||
return nil, nil, multierr.Append(err, cleanup())
|
||||
}
|
||||
cleanup = func() error {
|
||||
p.mu.Unlock()
|
||||
// Use a detached context to unlock the session. This is because the context passed to
|
||||
// SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
|
||||
// use [context.WithoutCancel] added in go1.21
|
||||
detachedCtx := context.Background()
|
||||
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
|
||||
}
|
||||
}
|
||||
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
|
||||
// need the version table because there is no versioning.
|
||||
if !p.cfg.disableVersioning {
|
||||
if err := p.ensureVersionTable(ctx, conn); err != nil {
|
||||
return nil, nil, multierr.Append(err, cleanup())
|
||||
}
|
||||
}
|
||||
return conn, cleanup, nil
|
||||
}
|
||||
|
||||
// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
|
||||
// will not be parsed again.
|
||||
//
|
||||
// Important: This function will mutate SQL migrations and is not safe for concurrent use.
|
||||
func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
|
||||
for _, m := range migrations {
|
||||
// If the migration is a SQL migration, and it has not been parsed, parse it.
|
||||
if m.Source.Type == TypeSQL && m.SQL == nil {
|
||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.SQL = &sqlMigration{
|
||||
UseTx: parsed.UseTx,
|
||||
UpStatements: parsed.Up,
|
||||
DownStatements: parsed.Down,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
||||
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
||||
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
}
|
||||
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
return nil
|
||||
}
|
||||
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
|
||||
})
|
||||
}
|
||||
|
||||
type missingMigration struct {
|
||||
versionID int64
|
||||
filename string
|
||||
}
|
||||
|
||||
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
|
||||
// migration is one that has a version less than the max version in the database.
|
||||
func checkMissingMigrations(
|
||||
dbMigrations []*database.ListMigrationsResult,
|
||||
fsMigrations []*migration,
|
||||
) []missingMigration {
|
||||
existing := make(map[int64]bool)
|
||||
var dbMaxVersion int64
|
||||
for _, m := range dbMigrations {
|
||||
existing[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
var missing []missingMigration
|
||||
for _, m := range fsMigrations {
|
||||
version := m.Source.Version
|
||||
if !existing[version] && version < dbMaxVersion {
|
||||
missing = append(missing, missingMigration{
|
||||
versionID: version,
|
||||
filename: m.filename(),
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(missing, func(i, j int) bool {
|
||||
return missing[i].versionID < missing[j].versionID
|
||||
})
|
||||
return missing
|
||||
}
|
||||
|
||||
// getMigration returns the migration with the given version. If no migration is found, then
|
||||
// ErrVersionNotFound is returned.
|
||||
func (p *Provider) getMigration(version int64) (*migration, error) {
|
||||
for _, m := range p.migrations {
|
||||
if m.Source.Version == version {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrVersionNotFound
|
||||
}
|
||||
|
||||
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
|
||||
m, err := p.getMigration(version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
result, err := p.store.GetMigration(ctx, conn, version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
// If the migration has already been applied, return an error, unless the migration is being
|
||||
// applied in the opposite direction. In that case, we allow the migration to be applied again.
|
||||
if result != nil && direction {
|
||||
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
||||
}
|
||||
|
||||
d := sqlparser.DirectionDown
|
||||
if direction {
|
||||
d = sqlparser.DirectionUp
|
||||
}
|
||||
results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
// TODO(mf): add support for limit and order. Also would be nice to refactor the list query to
|
||||
// support limiting the set.
|
||||
|
||||
status := make([]*MigrationStatus, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
migrationStatus := &MigrationStatus{
|
||||
Source: m.Source,
|
||||
State: StatePending,
|
||||
}
|
||||
dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
if dbResult != nil {
|
||||
migrationStatus.State = StateApplied
|
||||
migrationStatus.AppliedAt = dbResult.Timestamp
|
||||
}
|
||||
status = append(status, migrationStatus)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
res, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
sort.Slice(res, func(i, j int) bool {
|
||||
return res[i].Version > res[j].Version
|
||||
})
|
||||
return res[0].Version, nil
|
||||
}
|
|
@ -13,17 +13,25 @@ import (
|
|||
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
|
||||
// session-level advisory lock mechanism.
|
||||
//
|
||||
// This function creates a SessionLocker that can be used to acquire and release locks for
|
||||
// This function creates a SessionLocker that can be used to acquire and release a lock for
|
||||
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
|
||||
// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the
|
||||
// until the failure threshold is reached. The default lock duration is set to 5 minutes, and the
|
||||
// default unlock duration is set to 1 minute.
|
||||
//
|
||||
// If you have long running migrations, you may want to increase the lock duration.
|
||||
//
|
||||
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
|
||||
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
|
||||
cfg := sessionLockerConfig{
|
||||
lockID: DefaultLockID,
|
||||
lockTimeout: DefaultLockTimeout,
|
||||
unlockTimeout: DefaultUnlockTimeout,
|
||||
lockID: DefaultLockID,
|
||||
lockProbe: probe{
|
||||
periodSeconds: 5 * time.Second,
|
||||
failureThreshold: 60,
|
||||
},
|
||||
unlockProbe: probe{
|
||||
periodSeconds: 2 * time.Second,
|
||||
failureThreshold: 30,
|
||||
},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
|
@ -32,13 +40,13 @@ func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error
|
|||
}
|
||||
return &postgresSessionLocker{
|
||||
lockID: cfg.lockID,
|
||||
retryLock: retry.WithMaxDuration(
|
||||
cfg.lockTimeout,
|
||||
retry.NewConstant(2*time.Second),
|
||||
retryLock: retry.WithMaxRetries(
|
||||
cfg.lockProbe.failureThreshold,
|
||||
retry.NewConstant(cfg.lockProbe.periodSeconds),
|
||||
),
|
||||
retryUnlock: retry.WithMaxDuration(
|
||||
cfg.unlockTimeout,
|
||||
retry.NewConstant(2*time.Second),
|
||||
retryUnlock: retry.WithMaxRetries(
|
||||
cfg.unlockProbe.failureThreshold,
|
||||
retry.NewConstant(cfg.unlockProbe.periodSeconds),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/testdb"
|
||||
|
@ -30,8 +29,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
|||
)
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockID(lockID),
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||
)
|
||||
check.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
|
@ -60,8 +59,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
|||
})
|
||||
t.Run("lock_close_conn_unlock", func(t *testing.T) {
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||
)
|
||||
check.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
|
@ -103,10 +102,12 @@ func TestPostgresSessionLocker(t *testing.T) {
|
|||
// Exactly one connection should acquire the lock. While the other connections
|
||||
// should fail to acquire the lock and timeout.
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||
)
|
||||
check.NoError(t, err)
|
||||
// NOTE, we are not unlocking the lock, because we want to test that the lock is
|
||||
// released when the connection is closed.
|
||||
ch <- locker.SessionLock(ctx, conn)
|
||||
}()
|
||||
}
|
||||
|
@ -138,8 +139,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
|||
)
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockID(lockID),
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||
)
|
||||
check.NoError(t, err)
|
||||
|
||||
|
@ -179,6 +180,7 @@ func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var pgLocks []pgLock
|
||||
for rows.Next() {
|
||||
var p pgLock
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package lock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -10,11 +11,6 @@ const (
|
|||
//
|
||||
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
|
||||
DefaultLockID int64 = 5887940537704921958
|
||||
|
||||
// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
|
||||
// wait for the lock to be released) wait durations.
|
||||
DefaultLockTimeout time.Duration = 60 * time.Minute
|
||||
DefaultUnlockTimeout time.Duration = 1 * time.Minute
|
||||
)
|
||||
|
||||
// SessionLockerOption is used to configure a SessionLocker.
|
||||
|
@ -32,26 +28,65 @@ func WithLockID(lockID int64) SessionLockerOption {
|
|||
})
|
||||
}
|
||||
|
||||
// WithLockTimeout sets the max duration to wait for the lock to be acquired.
|
||||
func WithLockTimeout(duration time.Duration) SessionLockerOption {
|
||||
// WithLockTimeout sets the max duration to wait for the lock to be acquired. The total duration
|
||||
// will be the period times the failure threshold.
|
||||
//
|
||||
// By default, the lock timeout is 300s (5min), where the lock is retried every 5 seconds (period)
|
||||
// up to 60 times (failure threshold).
|
||||
//
|
||||
// The minimum period is 1 second, and the minimum failure threshold is 1.
|
||||
func WithLockTimeout(period, failureThreshold uint64) SessionLockerOption {
|
||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||
c.lockTimeout = duration
|
||||
if period < 1 {
|
||||
return errors.New("period must be greater than 0, minimum is 1")
|
||||
}
|
||||
if failureThreshold < 1 {
|
||||
return errors.New("failure threshold must be greater than 0, minimum is 1")
|
||||
}
|
||||
c.lockProbe = probe{
|
||||
periodSeconds: time.Duration(period) * time.Second,
|
||||
failureThreshold: failureThreshold,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithUnlockTimeout sets the max duration to wait for the lock to be released.
|
||||
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
|
||||
// WithUnlockTimeout sets the max duration to wait for the lock to be released. The total duration
|
||||
// will be the period times the failure threshold.
|
||||
//
|
||||
// By default, the lock timeout is 60s, where the lock is retried every 2 seconds (period) up to 30
|
||||
// times (failure threshold).
|
||||
//
|
||||
// The minimum period is 1 second, and the minimum failure threshold is 1.
|
||||
func WithUnlockTimeout(period, failureThreshold uint64) SessionLockerOption {
|
||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||
c.unlockTimeout = duration
|
||||
if period < 1 {
|
||||
return errors.New("period must be greater than 0, minimum is 1")
|
||||
}
|
||||
if failureThreshold < 1 {
|
||||
return errors.New("failure threshold must be greater than 0, minimum is 1")
|
||||
}
|
||||
c.unlockProbe = probe{
|
||||
periodSeconds: time.Duration(period) * time.Second,
|
||||
failureThreshold: failureThreshold,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type sessionLockerConfig struct {
|
||||
lockID int64
|
||||
lockTimeout time.Duration
|
||||
unlockTimeout time.Duration
|
||||
lockID int64
|
||||
lockProbe probe
|
||||
unlockProbe probe
|
||||
}
|
||||
|
||||
// probe is used to configure how often and how many times to retry a lock or unlock operation. The
|
||||
// total timeout will be the period times the failure threshold.
|
||||
type probe struct {
|
||||
// How often (in seconds) to perform the probe.
|
||||
periodSeconds time.Duration
|
||||
// Number of times to retry the probe.
|
||||
failureThreshold uint64
|
||||
}
|
||||
|
||||
var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)
|
||||
|
|
54
migration.go
54
migration.go
|
@ -18,22 +18,36 @@ import (
|
|||
// Both up and down functions may be nil, in which case the migration will be recorded in the
|
||||
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
|
||||
// a version without running any functions. See [GoFunc] for more details.
|
||||
func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
||||
m := Migration{
|
||||
func NewGoMigration(version int64, up, down *GoFunc) *Migration {
|
||||
m := &Migration{
|
||||
Type: TypeGo,
|
||||
Registered: true,
|
||||
Version: version,
|
||||
Next: -1, Previous: -1,
|
||||
goUp: up,
|
||||
goDown: down,
|
||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||
construct: true,
|
||||
}
|
||||
updateMode := func(f *GoFunc) *GoFunc {
|
||||
// infer mode from function
|
||||
if f.Mode == 0 {
|
||||
if f.RunTx != nil && f.RunDB == nil {
|
||||
f.Mode = TransactionEnabled
|
||||
}
|
||||
if f.RunTx == nil && f.RunDB != nil {
|
||||
f.Mode = TransactionDisabled
|
||||
}
|
||||
}
|
||||
return f
|
||||
}
|
||||
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
|
||||
// we will remove these fields in favor of [GoFunc].
|
||||
//
|
||||
// Note, this function does not do any validation. Validation is lazily done when the migration
|
||||
// is registered.
|
||||
if up != nil {
|
||||
m.goUp = updateMode(up)
|
||||
|
||||
if up.RunDB != nil {
|
||||
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
|
||||
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
|
||||
|
@ -45,6 +59,8 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
|||
}
|
||||
}
|
||||
if down != nil {
|
||||
m.goDown = updateMode(down)
|
||||
|
||||
if down.RunDB != nil {
|
||||
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
|
||||
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
|
||||
|
@ -55,12 +71,6 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
|||
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
|
||||
}
|
||||
}
|
||||
if m.goUp == nil {
|
||||
m.goUp = &GoFunc{Mode: TransactionEnabled}
|
||||
}
|
||||
if m.goDown == nil {
|
||||
m.goDown = &GoFunc{Mode: TransactionEnabled}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
|
@ -76,10 +86,6 @@ type Migration struct {
|
|||
|
||||
UpFnContext, DownFnContext GoMigrationContext
|
||||
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
||||
// These fields are used internally by goose and users are not expected to set them. Instead,
|
||||
// use [NewGoMigration] to create a new go migration.
|
||||
construct bool
|
||||
goUp, goDown *GoFunc
|
||||
|
||||
// These fields will be removed in a future major version. They are here for backwards
|
||||
// compatibility and are an implementation detail.
|
||||
|
@ -98,6 +104,26 @@ type Migration struct {
|
|||
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
|
||||
|
||||
noVersioning bool
|
||||
|
||||
// These fields are used internally by goose and users are not expected to set them. Instead,
|
||||
// use [NewGoMigration] to create a new go migration.
|
||||
construct bool
|
||||
goUp, goDown *GoFunc
|
||||
|
||||
sql sqlMigration
|
||||
}
|
||||
|
||||
type sqlMigration struct {
|
||||
// The Parsed field is used to track whether the SQL migration has been parsed. It serves as an
|
||||
// optimization to avoid parsing migrations that may never be needed. Typically, migrations are
|
||||
// incremental, and users often run only the most recent ones, making parsing of prior
|
||||
// migrations unnecessary in most cases.
|
||||
Parsed bool
|
||||
|
||||
// Parsed must be set to true before the following fields are used.
|
||||
UseTx bool
|
||||
Up []string
|
||||
Down []string
|
||||
}
|
||||
|
||||
// GoFunc represents a Go migration function.
|
||||
|
|
8
osfs.go
8
osfs.go
|
@ -18,3 +18,11 @@ func (osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(filepath.Fro
|
|||
func (osFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(filepath.FromSlash(name)) }
|
||||
|
||||
func (osFS) Glob(pattern string) ([]string, error) { return filepath.Glob(filepath.FromSlash(pattern)) }
|
||||
|
||||
type noopFS struct{}
|
||||
|
||||
var _ fs.FS = noopFS{}
|
||||
|
||||
func (f noopFS) Open(name string) (fs.File, error) {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
|
|
@ -0,0 +1,477 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// Provider is a goose migration provider.
|
||||
type Provider struct {
|
||||
// mu protects all accesses to the provider and must be held when calling operations on the
|
||||
// database.
|
||||
mu sync.Mutex
|
||||
|
||||
db *sql.DB
|
||||
store database.Store
|
||||
|
||||
fsys fs.FS
|
||||
cfg config
|
||||
|
||||
// migrations are ordered by version in ascending order.
|
||||
migrations []*Migration
|
||||
}
|
||||
|
||||
// NewProvider returns a new goose provider.
|
||||
//
|
||||
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
||||
// example, if the database dialect is "postgres", the database/sql driver could be
|
||||
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
|
||||
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
|
||||
// as using a custom table name or supplying a custom store implementation, see [WithStore].
|
||||
//
|
||||
// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use
|
||||
// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
|
||||
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
|
||||
// migrations using [fs.Sub].
|
||||
//
|
||||
// See [ProviderOption] for more information on configuring the provider.
|
||||
//
|
||||
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
|
||||
//
|
||||
// Experimental: This API is experimental and may change in the future.
|
||||
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("db must not be nil")
|
||||
}
|
||||
if fsys == nil {
|
||||
fsys = noopFS{}
|
||||
}
|
||||
cfg := config{
|
||||
registered: make(map[int64]*Migration),
|
||||
excludePaths: make(map[string]bool),
|
||||
excludeVersions: make(map[int64]bool),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Allow users to specify a custom store implementation, but only if they don't specify a
|
||||
// dialect. If they specify a dialect, we'll use the default store implementation.
|
||||
if dialect == "" && cfg.store == nil {
|
||||
return nil, errors.New("dialect must not be empty")
|
||||
}
|
||||
if dialect != "" && cfg.store != nil {
|
||||
return nil, errors.New("dialect must be empty when using a custom store implementation")
|
||||
}
|
||||
var store database.Store
|
||||
if dialect != "" {
|
||||
var err error
|
||||
store, err = database.NewStore(dialect, DefaultTablename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
store = cfg.store
|
||||
}
|
||||
if store.Tablename() == "" {
|
||||
return nil, errors.New("invalid store implementation: table name must not be empty")
|
||||
}
|
||||
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
|
||||
}
|
||||
|
||||
func newProvider(
|
||||
db *sql.DB,
|
||||
store database.Store,
|
||||
fsys fs.FS,
|
||||
cfg config,
|
||||
global map[int64]*Migration,
|
||||
) (*Provider, error) {
|
||||
// Collect migrations from the filesystem and merge with registered migrations.
|
||||
//
|
||||
// Note, we don't parse SQL migrations here. They are parsed lazily when required.
|
||||
|
||||
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
|
||||
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
|
||||
// we should make it optional.
|
||||
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
versionToGoMigration := make(map[int64]*Migration)
|
||||
// Add user-registered Go migrations from the provider.
|
||||
for version, m := range cfg.registered {
|
||||
versionToGoMigration[version] = m
|
||||
}
|
||||
// Add globally registered Go migrations.
|
||||
for version, m := range global {
|
||||
if _, ok := versionToGoMigration[version]; ok {
|
||||
return nil, fmt.Errorf("global go migration with version %d previously registered with provider", version)
|
||||
}
|
||||
versionToGoMigration[version] = m
|
||||
}
|
||||
// At this point we have all registered unique Go migrations (if any). We need to merge them
|
||||
// with SQL migrations from the filesystem.
|
||||
migrations, err := merge(filesystemSources, versionToGoMigration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(migrations) == 0 {
|
||||
return nil, ErrNoMigrations
|
||||
}
|
||||
return &Provider{
|
||||
db: db,
|
||||
fsys: fsys,
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
migrations: migrations,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the status of all migrations, merging the list of migrations from the database and
|
||||
// filesystem. The returned items are ordered by version, in ascending order.
|
||||
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
||||
return p.status(ctx)
|
||||
}
|
||||
|
||||
// GetDBVersion returns the highest version recorded in the database, regardless of the order in
|
||||
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
|
||||
// this method returns 4. If no migrations have been applied, it returns 0.
|
||||
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
||||
return p.getDBMaxVersion(ctx)
|
||||
}
|
||||
|
||||
// ListSources returns a list of all migration sources known to the provider, sorted in ascending
|
||||
// order by version. The path field may be empty for manually registered migrations, such as Go
|
||||
// migrations registered using the [WithGoMigrations] option.
|
||||
func (p *Provider) ListSources() []*Source {
|
||||
sources := make([]*Source, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
sources = append(sources, &Source{
|
||||
Type: m.Type,
|
||||
Path: m.Source,
|
||||
Version: m.Version,
|
||||
})
|
||||
}
|
||||
return sources
|
||||
}
|
||||
|
||||
// Ping attempts to ping the database to verify a connection is available.
|
||||
func (p *Provider) Ping(ctx context.Context) error {
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// Close closes the database connection initially supplied to the provider.
|
||||
func (p *Provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// ApplyVersion applies exactly one migration for the specified version. If there is no migration
|
||||
// available for the specified version, this method returns [ErrVersionNotFound]. If the migration
|
||||
// has already been applied, this method returns [ErrAlreadyApplied].
|
||||
//
|
||||
// The direction parameter determines the migration direction: true for up migration and false for
|
||||
// down migration.
|
||||
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
||||
res, err := p.apply(ctx, version, direction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// This should never happen, we must return exactly one result.
|
||||
if len(res) != 1 {
|
||||
versions := make([]string, 0, len(res))
|
||||
for _, r := range res {
|
||||
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected number of migrations applied running apply, expecting exactly one result: %v",
|
||||
strings.Join(versions, ","),
|
||||
)
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
|
||||
// empty list and nil error.
|
||||
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
||||
return p.up(ctx, false, math.MaxInt64)
|
||||
}
|
||||
|
||||
// UpByOne applies the next pending migration. If there is no next migration to apply, this method
|
||||
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
|
||||
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.up(ctx, true, math.MaxInt64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
// This should never happen, we must return exactly one result.
|
||||
if len(res) != 1 {
|
||||
versions := make([]string, 0, len(res))
|
||||
for _, r := range res {
|
||||
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected number of migrations applied running up-by-one, expecting exactly one result: %v",
|
||||
strings.Join(versions, ","),
|
||||
)
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// UpTo applies all pending migrations up to, and including, the specified version. If there are no
|
||||
// migrations to apply, this method returns empty list and nil error.
|
||||
//
|
||||
// For example, if there are three new migrations (9,10,11) and the current database version is 8
|
||||
// with a requested version of 10, only versions 9,10 will be applied.
|
||||
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
return p.up(ctx, false, version)
|
||||
}
|
||||
|
||||
// Down rolls back the most recently applied migration. If there are no migrations to rollback, this
|
||||
// method returns [ErrNoNextVersion].
|
||||
//
|
||||
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
|
||||
// the migration version. This only applies in scenarios where migrations are allowed to be applied
|
||||
// out of order.
|
||||
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
||||
res, err := p.down(ctx, true, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
// This should never happen, we must return exactly one result.
|
||||
if len(res) != 1 {
|
||||
versions := make([]string, 0, len(res))
|
||||
for _, r := range res {
|
||||
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected number of migrations applied running down, expecting exactly one result: %v",
|
||||
strings.Join(versions, ","),
|
||||
)
|
||||
}
|
||||
return res[0], nil
|
||||
}
|
||||
|
||||
// DownTo rolls back all migrations down to, but not including, the specified version.
|
||||
//
|
||||
// For example, if the current database version is 11,10,9... and the requested version is 9, only
|
||||
// migrations 11, 10 will be rolled back.
|
||||
//
|
||||
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
|
||||
// the migration version. This only applies in scenarios where migrations are allowed to be applied
|
||||
// out of order.
|
||||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||
if version < 0 {
|
||||
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
|
||||
}
|
||||
return p.down(ctx, false, version)
|
||||
}
|
||||
|
||||
// *** Internal methods ***
|
||||
|
||||
func (p *Provider) up(
|
||||
ctx context.Context,
|
||||
byOne bool,
|
||||
version int64,
|
||||
) (_ []*MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errInvalidVersion
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var apply []*Migration
|
||||
if p.cfg.disableVersioning {
|
||||
if byOne {
|
||||
return nil, errors.New("up-by-one not supported when versioning is disabled")
|
||||
}
|
||||
apply = p.migrations
|
||||
} else {
|
||||
// optimize(mf): Listing all migrations from the database isn't great. This is only required
|
||||
// to support the allow missing (out-of-order) feature. For users that don't use this
|
||||
// feature, we could just query the database for the current max version and then apply
|
||||
// migrations greater than that version.
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
apply, err = p.resolveUpMigrations(dbMigrations, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne)
|
||||
}
|
||||
|
||||
func (p *Provider) down(
|
||||
ctx context.Context,
|
||||
byOne bool,
|
||||
version int64,
|
||||
) (_ []*MigrationResult, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
if len(p.migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
var downMigrations []*Migration
|
||||
if byOne {
|
||||
last := p.migrations[len(p.migrations)-1]
|
||||
downMigrations = []*Migration{last}
|
||||
} else {
|
||||
downMigrations = p.migrations
|
||||
}
|
||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne)
|
||||
}
|
||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dbMigrations) == 0 {
|
||||
return nil, errMissingZeroVersion
|
||||
}
|
||||
// We never migrate the zero version down.
|
||||
if dbMigrations[0].Version == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var apply []*Migration
|
||||
for _, dbMigration := range dbMigrations {
|
||||
if dbMigration.Version <= version {
|
||||
break
|
||||
}
|
||||
m, err := p.getMigration(dbMigration.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apply = append(apply, m)
|
||||
}
|
||||
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne)
|
||||
}
|
||||
|
||||
func (p *Provider) apply(
|
||||
ctx context.Context,
|
||||
version int64,
|
||||
direction bool,
|
||||
) (_ []*MigrationResult, retErr error) {
|
||||
if version < 1 {
|
||||
return nil, errInvalidVersion
|
||||
}
|
||||
m, err := p.getMigration(version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
result, err := p.store.GetMigration(ctx, conn, version)
|
||||
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
// If the migration has already been applied, return an error. But, if the migration is being
|
||||
// rolled back, we allow the individual migration to be applied again.
|
||||
if result != nil && direction {
|
||||
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
||||
}
|
||||
d := sqlparser.DirectionDown
|
||||
if direction {
|
||||
d = sqlparser.DirectionUp
|
||||
}
|
||||
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
|
||||
}
|
||||
|
||||
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
status := make([]*MigrationStatus, 0, len(p.migrations))
|
||||
for _, m := range p.migrations {
|
||||
migrationStatus := &MigrationStatus{
|
||||
Source: &Source{
|
||||
Type: m.Type,
|
||||
Path: m.Source,
|
||||
Version: m.Version,
|
||||
},
|
||||
State: StatePending,
|
||||
}
|
||||
dbResult, err := p.store.GetMigration(ctx, conn, m.Version)
|
||||
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
if dbResult != nil {
|
||||
migrationStatus.State = StateApplied
|
||||
migrationStatus.AppliedAt = dbResult.Timestamp
|
||||
}
|
||||
status = append(status, migrationStatus)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
}()
|
||||
|
||||
res, err := p.store.ListMigrations(ctx, conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return 0, errMissingZeroVersion
|
||||
}
|
||||
// Sort in descending order.
|
||||
sort.Slice(res, func(i, j int) bool {
|
||||
return res[i].Version > res[j].Version
|
||||
})
|
||||
// Return the highest version.
|
||||
return res[0].Version, nil
|
||||
}
|
|
@ -1,15 +1,12 @@
|
|||
package provider
|
||||
package goose
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
// fileSources represents a collection of migration files on the filesystem.
|
||||
|
@ -18,25 +15,6 @@ type fileSources struct {
|
|||
goSources []Source
|
||||
}
|
||||
|
||||
// TODO(mf): remove?
|
||||
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
for _, source := range s.goSources {
|
||||
if source.Version == version {
|
||||
return &source
|
||||
}
|
||||
}
|
||||
case TypeSQL:
|
||||
for _, source := range s.sqlSources {
|
||||
if source.Version == version {
|
||||
return &source
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
|
||||
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
|
||||
// be nil, in which case an empty fileSources is returned.
|
||||
|
@ -46,7 +24,12 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source {
|
|||
//
|
||||
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
|
||||
// migration sources from the filesystem.
|
||||
func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
|
||||
func collectFilesystemSources(
|
||||
fsys fs.FS,
|
||||
strict bool,
|
||||
excludePaths map[string]bool,
|
||||
excludeVersions map[int64]bool,
|
||||
) (*fileSources, error) {
|
||||
if fsys == nil {
|
||||
return new(fileSources), nil
|
||||
}
|
||||
|
@ -62,8 +45,11 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
|||
}
|
||||
for _, fullpath := range files {
|
||||
base := filepath.Base(fullpath)
|
||||
// Skip explicit excludes or Go test files.
|
||||
if excludes[base] || strings.HasSuffix(base, "_test.go") {
|
||||
if strings.HasSuffix(base, "_test.go") {
|
||||
continue
|
||||
}
|
||||
if excludePaths[base] {
|
||||
// TODO(mf): log this?
|
||||
continue
|
||||
}
|
||||
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
|
||||
|
@ -71,13 +57,17 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
|||
// filenames, but still have versioned migrations within the same directory. For
|
||||
// example, a user could have a helpers.go file which contains unexported helper
|
||||
// functions for migrations.
|
||||
version, err := goose.NumericComponent(base)
|
||||
version, err := NumericComponent(base)
|
||||
if err != nil {
|
||||
if strict {
|
||||
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if excludeVersions[version] {
|
||||
// TODO: log this?
|
||||
continue
|
||||
}
|
||||
// Ensure there are no duplicate versions.
|
||||
if existing, ok := versionToBaseLookup[version]; ok {
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
|
@ -101,7 +91,7 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
|||
})
|
||||
default:
|
||||
// Should never happen since we already filtered out all other file types.
|
||||
return nil, fmt.Errorf("unknown migration type: %s", base)
|
||||
return nil, fmt.Errorf("invalid file extension: %q", base)
|
||||
}
|
||||
// Add the version to the lookup map.
|
||||
versionToBaseLookup[version] = base
|
||||
|
@ -110,15 +100,25 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
|||
return sources, nil
|
||||
}
|
||||
|
||||
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
|
||||
var migrations []*migration
|
||||
migrationLookup := make(map[int64]*migration)
|
||||
func newSQLMigration(source Source) *Migration {
|
||||
return &Migration{
|
||||
Type: source.Type,
|
||||
Version: source.Version,
|
||||
Source: source.Path,
|
||||
construct: true,
|
||||
Next: -1, Previous: -1,
|
||||
sql: sqlMigration{
|
||||
Parsed: false, // SQL migrations are parsed lazily.
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func merge(sources *fileSources, registerd map[int64]*Migration) ([]*Migration, error) {
|
||||
var migrations []*Migration
|
||||
migrationLookup := make(map[int64]*Migration)
|
||||
// Add all SQL migrations to the list of migrations.
|
||||
for _, source := range sources.sqlSources {
|
||||
m := &migration{
|
||||
Source: source,
|
||||
SQL: nil, // SQL migrations are parsed lazily.
|
||||
}
|
||||
m := newSQLMigration(source)
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[source.Version] = m
|
||||
}
|
||||
|
@ -147,38 +147,24 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
|
|||
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
|
||||
// the SQL migration files.
|
||||
for version, r := range registerd {
|
||||
fullpath := r.fullpath
|
||||
if fullpath == "" {
|
||||
if s := sources.lookup(TypeGo, version); s != nil {
|
||||
fullpath = s.Path
|
||||
}
|
||||
}
|
||||
// Ensure there are no duplicate versions.
|
||||
if existing, ok := migrationLookup[version]; ok {
|
||||
fullpath := r.fullpath
|
||||
fullpath := r.Source
|
||||
if fullpath == "" {
|
||||
fullpath = "manually registered (no source)"
|
||||
}
|
||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||
version,
|
||||
existing.Source.Path,
|
||||
existing.Source,
|
||||
fullpath,
|
||||
)
|
||||
}
|
||||
m := &migration{
|
||||
Source: Source{
|
||||
Type: TypeGo,
|
||||
Path: fullpath, // May be empty if migration was registered manually.
|
||||
Version: version,
|
||||
},
|
||||
Go: r,
|
||||
}
|
||||
migrations = append(migrations, m)
|
||||
migrationLookup[version] = m
|
||||
migrations = append(migrations, r)
|
||||
migrationLookup[version] = r
|
||||
}
|
||||
// Sort migrations by version in ascending order.
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Source.Version < migrations[j].Source.Version
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
})
|
||||
return migrations, nil
|
||||
}
|
||||
|
@ -203,11 +189,3 @@ func unregisteredError(unregistered []string) error {
|
|||
|
||||
return errors.New(b.String())
|
||||
}
|
||||
|
||||
type noopFS struct{}
|
||||
|
||||
var _ fs.FS = noopFS{}
|
||||
|
||||
func (f noopFS) Open(name string) (fs.File, error) {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package provider
|
||||
package goose
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
@ -12,21 +12,21 @@ import (
|
|||
func TestCollectFileSources(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("nil_fsys", func(t *testing.T) {
|
||||
sources, err := collectFilesystemSources(nil, false, nil)
|
||||
sources, err := collectFilesystemSources(nil, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Bool(t, sources != nil, true)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
})
|
||||
t.Run("noop_fsys", func(t *testing.T) {
|
||||
sources, err := collectFilesystemSources(noopFS{}, false, nil)
|
||||
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Bool(t, sources != nil, true)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
})
|
||||
t.Run("empty_fsys", func(t *testing.T) {
|
||||
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil)
|
||||
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
|
@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
|
|||
"00000_foo.sql": sqlMapFile,
|
||||
}
|
||||
// strict disable - should not error
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
check.Number(t, len(sources.sqlSources), 0)
|
||||
// strict enabled - should error
|
||||
_, err = collectFilesystemSources(mapFS, true, nil)
|
||||
_, err = collectFilesystemSources(mapFS, true, nil, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "migration version must be greater than zero")
|
||||
})
|
||||
t.Run("collect", func(t *testing.T) {
|
||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 4)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
|
@ -76,6 +76,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
"00002_bar.sql": true,
|
||||
"00110_qux.sql": true,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 2)
|
||||
|
@ -96,7 +97,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
_, err = collectFilesystemSources(fsys, true, nil)
|
||||
_, err = collectFilesystemSources(fsys, true, nil, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
||||
})
|
||||
|
@ -108,7 +109,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
"4_qux.sql": sqlMapFile,
|
||||
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
||||
}
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 4)
|
||||
check.Number(t, len(sources.goSources), 0)
|
||||
|
@ -123,7 +124,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
|
||||
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
||||
}
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(sources.sqlSources), 2)
|
||||
check.Number(t, len(sources.goSources), 1)
|
||||
|
@ -142,7 +143,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
"001_foo.sql": sqlMapFile,
|
||||
"01_bar.sql": sqlMapFile,
|
||||
}
|
||||
_, err := collectFilesystemSources(mapFS, false, nil)
|
||||
_, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||
})
|
||||
|
@ -158,7 +159,7 @@ func TestCollectFileSources(t *testing.T) {
|
|||
t.Helper()
|
||||
f, err := fs.Sub(mapFS, dirpath)
|
||||
check.NoError(t, err)
|
||||
got, err := collectFilesystemSources(f, false, nil)
|
||||
got, err := collectFilesystemSources(f, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(got.sqlSources), len(sqlSources))
|
||||
check.Number(t, len(got.goSources), 0)
|
||||
|
@ -194,27 +195,21 @@ func TestMerge(t *testing.T) {
|
|||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, len(sources.sqlSources), 1)
|
||||
check.Equal(t, len(sources.goSources), 2)
|
||||
src1 := sources.lookup(TypeSQL, 1)
|
||||
check.Bool(t, src1 != nil, true)
|
||||
src2 := sources.lookup(TypeGo, 2)
|
||||
check.Bool(t, src2 != nil, true)
|
||||
src3 := sources.lookup(TypeGo, 3)
|
||||
check.Bool(t, src3 != nil, true)
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
2: newGoMigration("", nil, nil),
|
||||
3: newGoMigration("", nil, nil),
|
||||
})
|
||||
registered := map[int64]*Migration{
|
||||
2: NewGoMigration(2, nil, nil),
|
||||
3: NewGoMigration(3, nil, nil),
|
||||
}
|
||||
migrations, err := merge(sources, registered)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 3)
|
||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||
})
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
_, err := merge(sources, nil)
|
||||
|
@ -224,18 +219,16 @@ func TestMerge(t *testing.T) {
|
|||
check.Contains(t, err.Error(), "00003_baz.go")
|
||||
})
|
||||
t.Run("unregistered_some", func(t *testing.T) {
|
||||
_, err := merge(sources, map[int64]*goMigration{
|
||||
2: newGoMigration("", nil, nil),
|
||||
})
|
||||
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
||||
check.Contains(t, err.Error(), "00003_baz.go")
|
||||
})
|
||||
t.Run("duplicate_sql", func(t *testing.T) {
|
||||
_, err := merge(sources, map[int64]*goMigration{
|
||||
1: newGoMigration("", nil, nil), // duplicate. SQL already exists.
|
||||
2: newGoMigration("", nil, nil),
|
||||
3: newGoMigration("", nil, nil),
|
||||
_, err := merge(sources, map[int64]*Migration{
|
||||
1: NewGoMigration(1, nil, nil), // duplicate. SQL already exists.
|
||||
2: NewGoMigration(2, nil, nil),
|
||||
3: NewGoMigration(3, nil, nil),
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||
|
@ -250,13 +243,13 @@ func TestMerge(t *testing.T) {
|
|||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
3: newGoMigration("", nil, nil),
|
||||
migrations, err := merge(sources, map[int64]*Migration{
|
||||
3: NewGoMigration(3, nil, nil),
|
||||
// 4 is missing
|
||||
6: newGoMigration("", nil, nil),
|
||||
6: NewGoMigration(6, nil, nil),
|
||||
})
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 5)
|
||||
|
@ -274,20 +267,20 @@ func TestMerge(t *testing.T) {
|
|||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||
check.NoError(t, err)
|
||||
t.Run("unregistered_all", func(t *testing.T) {
|
||||
migrations, err := merge(sources, map[int64]*goMigration{
|
||||
migrations, err := merge(sources, map[int64]*Migration{
|
||||
// This is the only Go file on disk.
|
||||
2: newGoMigration("", nil, nil),
|
||||
2: NewGoMigration(2, nil, nil),
|
||||
// These are not on disk. Explicitly registered.
|
||||
3: newGoMigration("", nil, nil),
|
||||
6: newGoMigration("", nil, nil),
|
||||
3: NewGoMigration(3, nil, nil),
|
||||
6: NewGoMigration(6, nil, nil),
|
||||
})
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(migrations), 4)
|
||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
|
||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
|
||||
})
|
||||
|
@ -308,15 +301,15 @@ func TestCheckMissingMigrations(t *testing.T) {
|
|||
{Version: 5},
|
||||
{Version: 7}, // <-- database max version_id
|
||||
}
|
||||
fsMigrations := []*migration{
|
||||
newMigrationVersion(1),
|
||||
newMigrationVersion(2), // missing migration
|
||||
newMigrationVersion(3),
|
||||
newMigrationVersion(4),
|
||||
newMigrationVersion(5),
|
||||
newMigrationVersion(6), // missing migration
|
||||
newMigrationVersion(7), // ----- database max version_id -----
|
||||
newMigrationVersion(8), // new migration
|
||||
fsMigrations := []*Migration{
|
||||
newSQLMigration(Source{Version: 1}),
|
||||
newSQLMigration(Source{Version: 2}), // missing migration
|
||||
newSQLMigration(Source{Version: 3}),
|
||||
newSQLMigration(Source{Version: 4}),
|
||||
newSQLMigration(Source{Version: 5}),
|
||||
newSQLMigration(Source{Version: 6}), // missing migration
|
||||
newSQLMigration(Source{Version: 7}), // ----- database max version_id -----
|
||||
newSQLMigration(Source{Version: 8}), // new migration
|
||||
}
|
||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
|
@ -334,9 +327,9 @@ func TestCheckMissingMigrations(t *testing.T) {
|
|||
{Version: 5},
|
||||
{Version: 2},
|
||||
}
|
||||
fsMigrations := []*migration{
|
||||
newMigrationVersion(3), // new migration
|
||||
newMigrationVersion(4), // new migration
|
||||
fsMigrations := []*Migration{
|
||||
NewGoMigration(3, nil, nil), // new migration
|
||||
NewGoMigration(4, nil, nil), // new migration
|
||||
}
|
||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
|
@ -345,24 +338,19 @@ func TestCheckMissingMigrations(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func newMigrationVersion(version int64) *migration {
|
||||
return &migration{
|
||||
Source: Source{
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func assertMigration(t *testing.T, got *migration, want Source) {
|
||||
func assertMigration(t *testing.T, got *Migration, want Source) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.Source, want)
|
||||
switch got.Source.Type {
|
||||
check.Equal(t, got.Type, want.Type)
|
||||
check.Equal(t, got.Version, want.Version)
|
||||
check.Equal(t, got.Source, want.Path)
|
||||
switch got.Type {
|
||||
case TypeGo:
|
||||
check.Bool(t, got.Go != nil, true)
|
||||
check.Bool(t, got.goUp != nil, true)
|
||||
check.Bool(t, got.goDown != nil, true)
|
||||
case TypeSQL:
|
||||
check.Bool(t, got.SQL == nil, true)
|
||||
check.Bool(t, got.sql.Parsed, false)
|
||||
default:
|
||||
t.Fatalf("unknown migration type: %s", got.Source.Type)
|
||||
t.Fatalf("unknown migration type: %s", got.Type)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package provider
|
||||
package goose
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -16,8 +16,8 @@ var (
|
|||
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
|
||||
ErrNoMigrations = errors.New("no migrations found")
|
||||
|
||||
// ErrNoNextVersion when the next migration version is not found.
|
||||
ErrNoNextVersion = errors.New("no next version found")
|
||||
// errInvalidVersion is returned when a migration version is invalid.
|
||||
errInvalidVersion = errors.New("version must be greater than 0")
|
||||
)
|
||||
|
||||
// PartialError is returned when a migration fails, but some migrations already got applied.
|
|
@ -1,8 +1,6 @@
|
|||
package provider
|
||||
package goose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -12,12 +10,11 @@ import (
|
|||
|
||||
const (
|
||||
// DefaultTablename is the default name of the database table used to track history of applied
|
||||
// migrations. It can be overridden using the [WithTableName] option when creating a new
|
||||
// provider.
|
||||
// migrations.
|
||||
DefaultTablename = "goose_db_version"
|
||||
)
|
||||
|
||||
// ProviderOption is a configuration option for a goose provider.
|
||||
// ProviderOption is a configuration option for a goose goose.
|
||||
type ProviderOption interface {
|
||||
apply(*config) error
|
||||
}
|
||||
|
@ -84,84 +81,75 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
|
|||
})
|
||||
}
|
||||
|
||||
// WithExcludes excludes the given file names from the list of migrations.
|
||||
//
|
||||
// If WithExcludes is called multiple times, the list of excludes is merged.
|
||||
func WithExcludes(excludes []string) ProviderOption {
|
||||
// WithExcludeNames excludes the given file name from the list of migrations. If called multiple
|
||||
// times, the list of excludes is merged.
|
||||
func WithExcludeNames(excludes []string) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
for _, name := range excludes {
|
||||
c.excludes[name] = true
|
||||
if _, ok := c.excludePaths[name]; ok {
|
||||
return fmt.Errorf("duplicate exclude file name: %s", name)
|
||||
}
|
||||
c.excludePaths[name] = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration].
|
||||
type GoMigrationFunc struct {
|
||||
// One of the following must be set:
|
||||
Run func(context.Context, *sql.Tx) error
|
||||
// -- OR --
|
||||
RunNoTx func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
// WithGoMigration registers a Go migration with the given version.
|
||||
//
|
||||
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
|
||||
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
|
||||
// set.
|
||||
func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption {
|
||||
// WithExcludeVersions excludes the given versions from the list of migrations. If called multiple
|
||||
// times, the list of excludes is merged.
|
||||
func WithExcludeVersions(versions []int64) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
if version < 1 {
|
||||
return errors.New("version must be greater than zero")
|
||||
}
|
||||
if _, ok := c.registered[version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", version)
|
||||
}
|
||||
// Allow nil up/down functions. This enables users to apply "no-op" migrations, while
|
||||
// versioning them.
|
||||
if up != nil {
|
||||
if up.Run == nil && up.RunNoTx == nil {
|
||||
return fmt.Errorf("go migration with version %d must have an up function", version)
|
||||
for _, version := range versions {
|
||||
if version < 1 {
|
||||
return errInvalidVersion
|
||||
}
|
||||
if up.Run != nil && up.RunNoTx != nil {
|
||||
return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version)
|
||||
if _, ok := c.excludeVersions[version]; ok {
|
||||
return fmt.Errorf("duplicate excludes version: %d", version)
|
||||
}
|
||||
}
|
||||
if down != nil {
|
||||
if down.Run == nil && down.RunNoTx == nil {
|
||||
return fmt.Errorf("go migration with version %d must have a down function", version)
|
||||
}
|
||||
if down.Run != nil && down.RunNoTx != nil {
|
||||
return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version)
|
||||
}
|
||||
}
|
||||
c.registered[version] = &goMigration{
|
||||
up: up,
|
||||
down: down,
|
||||
c.excludeVersions[version] = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default,
|
||||
// WithGoMigrations registers Go migrations with the provider. If a Go migration with the same
|
||||
// version has already been registered, an error will be returned.
|
||||
//
|
||||
// Go migrations must be constructed using the [NewGoMigration] function.
|
||||
func WithGoMigrations(migrations ...*Migration) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
for _, m := range migrations {
|
||||
if _, ok := c.registered[m.Version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||
}
|
||||
if err := checkGoMigration(m); err != nil {
|
||||
return fmt.Errorf("invalid go migration: %w", err)
|
||||
}
|
||||
c.registered[m.Version] = m
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithAllowOutofOrder allows the provider to apply missing (out-of-order) migrations. By default,
|
||||
// goose will raise an error if it encounters a missing migration.
|
||||
//
|
||||
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
|
||||
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
|
||||
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
|
||||
// by new migrations.
|
||||
func WithAllowedMissing(b bool) ProviderOption {
|
||||
// For example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is
|
||||
// true, then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order
|
||||
// of applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first,
|
||||
// followed by new migrations.
|
||||
func WithAllowOutofOrder(b bool) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
c.allowMissing = b
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations
|
||||
// WithDisableVersioning disables versioning. Disabling versioning allows applying migrations
|
||||
// without tracking the versions in the database schema table. Useful for tests, seeding a database
|
||||
// or running ad-hoc queries. By default, goose will track all versions in the database schema
|
||||
// table.
|
||||
func WithDisabledVersioning(b bool) ProviderOption {
|
||||
func WithDisableVersioning(b bool) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
c.disableVersioning = b
|
||||
return nil
|
||||
|
@ -171,12 +159,13 @@ func WithDisabledVersioning(b bool) ProviderOption {
|
|||
type config struct {
|
||||
store database.Store
|
||||
|
||||
verbose bool
|
||||
excludes map[string]bool
|
||||
verbose bool
|
||||
excludePaths map[string]bool
|
||||
excludeVersions map[int64]bool
|
||||
|
||||
// Go migrations registered by the user. These will be merged/resolved with migrations from the
|
||||
// filesystem and init() functions.
|
||||
registered map[int64]*goMigration
|
||||
// Go migrations registered by the user. These will be merged/resolved against the globally
|
||||
// registered migrations.
|
||||
registered map[int64]*Migration
|
||||
|
||||
// Locking options
|
||||
lockEnabled bool
|
|
@ -1,4 +1,4 @@
|
|||
package provider_test
|
||||
package goose_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
@ -6,9 +6,9 @@ import (
|
|||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/provider"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
|
@ -24,45 +24,42 @@ func TestNewProvider(t *testing.T) {
|
|||
}
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
// Empty dialect not allowed
|
||||
_, err = provider.NewProvider("", db, fsys)
|
||||
_, err = goose.NewProvider("", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Invalid dialect not allowed
|
||||
_, err = provider.NewProvider("unknown-dialect", db, fsys)
|
||||
_, err = goose.NewProvider("unknown-dialect", db, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil db not allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, nil, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil fsys not allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil)
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, nil, fsys)
|
||||
check.HasError(t, err)
|
||||
// Nil store not allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(nil))
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil))
|
||||
check.HasError(t, err)
|
||||
// Cannot set both dialect and store
|
||||
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
||||
check.NoError(t, err)
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(store))
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(store))
|
||||
check.HasError(t, err)
|
||||
// Multiple stores not allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
|
||||
provider.WithStore(store),
|
||||
provider.WithStore(store),
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||
goose.WithStore(store),
|
||||
goose.WithStore(store),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
})
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
// Valid dialect, db, and fsys allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys)
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys)
|
||||
check.NoError(t, err)
|
||||
// Valid dialect, db, fsys, and verbose allowed
|
||||
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
provider.WithVerbose(testing.Verbose()),
|
||||
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
goose.WithVerbose(testing.Verbose()),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
// Custom store allowed
|
||||
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
||||
check.NoError(t, err)
|
||||
_, err = provider.NewProvider("", db, nil, provider.WithStore(store))
|
||||
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
|
||||
check.HasError(t, err)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,460 @@
|
|||
package goose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingZeroVersion = errors.New("missing zero version migration")
|
||||
)
|
||||
|
||||
func (p *Provider) resolveUpMigrations(
|
||||
dbVersions []*database.ListMigrationsResult,
|
||||
version int64,
|
||||
) ([]*Migration, error) {
|
||||
var apply []*Migration
|
||||
var dbMaxVersion int64
|
||||
// dbAppliedVersions is a map of all applied migrations in the database.
|
||||
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
|
||||
for _, m := range dbVersions {
|
||||
dbAppliedVersions[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
|
||||
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
|
||||
// migrations entirely. At the moment this is not supported, but leaving this comment because
|
||||
// that's where that logic would be handled.
|
||||
//
|
||||
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
|
||||
// sure if this is a common use case, but it's possible.
|
||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||
var collected []string
|
||||
for _, v := range missingMigrations {
|
||||
collected = append(collected, fmt.Sprintf("%d", v.versionID))
|
||||
}
|
||||
msg := "migration"
|
||||
if len(collected) > 1 {
|
||||
msg += "s"
|
||||
}
|
||||
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
|
||||
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
||||
)
|
||||
}
|
||||
for _, v := range missingMigrations {
|
||||
m, err := p.getMigration(v.versionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apply = append(apply, m)
|
||||
}
|
||||
// filter all migrations with a version greater than the supplied version (min) and less than or
|
||||
// equal to the requested version (max). Skip any migrations that have already been applied.
|
||||
for _, m := range p.migrations {
|
||||
if dbAppliedVersions[m.Version] {
|
||||
continue
|
||||
}
|
||||
if m.Version > dbMaxVersion && m.Version <= version {
|
||||
apply = append(apply, m)
|
||||
}
|
||||
}
|
||||
return apply, nil
|
||||
}
|
||||
|
||||
func (p *Provider) prepareMigration(fsys fs.FS, m *Migration, direction bool) error {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
if m.goUp.Mode == 0 {
|
||||
return errors.New("go up migration mode is not set")
|
||||
}
|
||||
if m.goDown.Mode == 0 {
|
||||
return errors.New("go down migration mode is not set")
|
||||
}
|
||||
var useTx bool
|
||||
if direction {
|
||||
useTx = m.goUp.Mode == TransactionEnabled
|
||||
} else {
|
||||
useTx = m.goDown.Mode == TransactionEnabled
|
||||
}
|
||||
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB,
|
||||
// but are locking the database with *sql.Conn. If the caller sets max open connections to
|
||||
// 1, then this will deadlock because the Go migration will try to acquire a connection from
|
||||
// the pool, but the pool is exhausted because the lock is held.
|
||||
//
|
||||
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to
|
||||
// use *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is
|
||||
// a bit of an edge case. For now, we guard against this scenario by checking the max open
|
||||
// connections and returning an error.
|
||||
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
|
||||
if !useTx {
|
||||
return errors.New("potential deadlock detected: cannot run Go migration without a transaction when max open connections set to 1")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case TypeSQL:
|
||||
if m.sql.Parsed {
|
||||
return nil
|
||||
}
|
||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.sql.Parsed = true
|
||||
m.sql.UseTx = parsed.UseTx
|
||||
m.sql.Up, m.sql.Down = parsed.Up, parsed.Down
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid migration type: %+v", m)
|
||||
}
|
||||
|
||||
// runMigrations runs migrations sequentially in the given direction. If the migrations list is
|
||||
// empty, return nil without error.
|
||||
func (p *Provider) runMigrations(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
migrations []*Migration,
|
||||
direction sqlparser.Direction,
|
||||
byOne bool,
|
||||
) ([]*MigrationResult, error) {
|
||||
if len(migrations) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
apply := migrations
|
||||
if byOne {
|
||||
apply = migrations[:1]
|
||||
}
|
||||
|
||||
// SQL migrations are lazily parsed in both directions. This is done before attempting to run
|
||||
// any migrations to catch errors early and prevent leaving the database in an incomplete state.
|
||||
|
||||
for _, m := range apply {
|
||||
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
|
||||
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
|
||||
// to run in a transaction.
|
||||
|
||||
// feat(mf): this is where we can (optionally) group multiple migrations to be run in a single
|
||||
// transaction. The default is to apply each migration sequentially on its own. See the
|
||||
// following issues for more details:
|
||||
// - https://github.com/pressly/goose/issues/485
|
||||
// - https://github.com/pressly/goose/issues/222
|
||||
//
|
||||
// Be careful, we can't use a single transaction for all migrations because some may be marked
|
||||
// as not using a transaction.
|
||||
|
||||
var results []*MigrationResult
|
||||
for _, m := range apply {
|
||||
current := &MigrationResult{
|
||||
Source: &Source{
|
||||
Type: m.Type,
|
||||
Path: m.Source,
|
||||
Version: m.Version,
|
||||
},
|
||||
Direction: direction.String(),
|
||||
Empty: isEmpty(m, direction.ToBool()),
|
||||
}
|
||||
start := time.Now()
|
||||
if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil {
|
||||
// TODO(mf): we should also return the pending migrations here, the remaining items in
|
||||
// the apply slice.
|
||||
current.Error = err
|
||||
current.Duration = time.Since(start)
|
||||
return nil, &PartialError{
|
||||
Applied: results,
|
||||
Failed: current,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
current.Duration = time.Since(start)
|
||||
results = append(results, current)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (p *Provider) runIndividually(
|
||||
ctx context.Context,
|
||||
conn *sql.Conn,
|
||||
m *Migration,
|
||||
direction bool,
|
||||
) error {
|
||||
useTx, err := useTx(m, direction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useTx {
|
||||
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := runMigration(ctx, tx, m, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.maybeInsertOrDelete(ctx, tx, m.Version, direction)
|
||||
})
|
||||
}
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
// Note, we are using *sql.DB instead of *sql.Conn because it's the Go migration contract.
|
||||
// This may be a deadlock scenario if max open connections is set to 1 AND a lock is
|
||||
// acquired on the database. In this case, the migration will block forever unable to
|
||||
// acquire a connection from the pool.
|
||||
//
|
||||
// For now, we guard against this scenario by checking the max open connections and
|
||||
// returning an error in the prepareMigration function.
|
||||
if err := runMigration(ctx, p.db, m, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction)
|
||||
case TypeSQL:
|
||||
if err := runMigration(ctx, conn, m, direction); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.maybeInsertOrDelete(ctx, conn, m.Version, direction)
|
||||
}
|
||||
return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m)
|
||||
}
|
||||
|
||||
func (p *Provider) maybeInsertOrDelete(
|
||||
ctx context.Context,
|
||||
db database.DBTxConn,
|
||||
version int64,
|
||||
direction bool,
|
||||
) error {
|
||||
// If versioning is disabled, we don't need to insert or delete the migration version.
|
||||
if p.cfg.disableVersioning {
|
||||
return nil
|
||||
}
|
||||
if direction {
|
||||
return p.store.Insert(ctx, db, database.InsertRequest{Version: version})
|
||||
}
|
||||
return p.store.Delete(ctx, db, version)
|
||||
}
|
||||
|
||||
// beginTx begins a transaction and runs the given function. If the function returns an error, the
|
||||
// transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (retErr error) {
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
retErr = multierr.Append(retErr, tx.Rollback())
|
||||
}
|
||||
}()
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
|
||||
p.mu.Lock()
|
||||
conn, err := p.db.Conn(ctx)
|
||||
if err != nil {
|
||||
p.mu.Unlock()
|
||||
return nil, nil, err
|
||||
}
|
||||
// cleanup is a function that cleans up the connection, and optionally, the session lock.
|
||||
cleanup := func() error {
|
||||
p.mu.Unlock()
|
||||
return conn.Close()
|
||||
}
|
||||
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
|
||||
if err := l.SessionLock(ctx, conn); err != nil {
|
||||
return nil, nil, multierr.Append(err, cleanup())
|
||||
}
|
||||
// A lock was acquired, so we need to unlock the session when we're done. This is done by
|
||||
// returning a cleanup function that unlocks the session and closes the connection.
|
||||
cleanup = func() error {
|
||||
p.mu.Unlock()
|
||||
// Use a detached context to unlock the session. This is because the context passed to
|
||||
// SessionLock may have been canceled, and we don't want to cancel the unlock.
|
||||
//
|
||||
// TODO(mf): use context.WithoutCancel added in go1.21
|
||||
detachedCtx := context.Background()
|
||||
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
|
||||
}
|
||||
}
|
||||
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
|
||||
// need the version table because no versions are being recorded.
|
||||
if !p.cfg.disableVersioning {
|
||||
if err := p.ensureVersionTable(ctx, conn); err != nil {
|
||||
return nil, nil, multierr.Append(err, cleanup())
|
||||
}
|
||||
}
|
||||
return conn, cleanup, nil
|
||||
}
|
||||
|
||||
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
||||
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
||||
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
}
|
||||
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.cfg.disableVersioning {
|
||||
return nil
|
||||
}
|
||||
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
|
||||
})
|
||||
}
|
||||
|
||||
type missingMigration struct {
|
||||
versionID int64
|
||||
}
|
||||
|
||||
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
|
||||
// migration is one that has a version less than the max version in the database.
|
||||
func checkMissingMigrations(
|
||||
dbMigrations []*database.ListMigrationsResult,
|
||||
fsMigrations []*Migration,
|
||||
) []missingMigration {
|
||||
existing := make(map[int64]bool)
|
||||
var dbMaxVersion int64
|
||||
for _, m := range dbMigrations {
|
||||
existing[m.Version] = true
|
||||
if m.Version > dbMaxVersion {
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
var missing []missingMigration
|
||||
for _, m := range fsMigrations {
|
||||
version := m.Version
|
||||
if !existing[version] && version < dbMaxVersion {
|
||||
missing = append(missing, missingMigration{
|
||||
versionID: version,
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(missing, func(i, j int) bool {
|
||||
return missing[i].versionID < missing[j].versionID
|
||||
})
|
||||
return missing
|
||||
}
|
||||
|
||||
// getMigration returns the migration for the given version. If no migration is found, then
|
||||
// ErrVersionNotFound is returned.
|
||||
func (p *Provider) getMigration(version int64) (*Migration, error) {
|
||||
for _, m := range p.migrations {
|
||||
if m.Version == version {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrVersionNotFound
|
||||
}
|
||||
|
||||
// useTx is a helper function that returns true if the migration should be run in a transaction. It
|
||||
// must only be called after the migration has been parsed and initialized.
|
||||
func useTx(m *Migration, direction bool) (bool, error) {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
if m.goUp.Mode == 0 || m.goDown.Mode == 0 {
|
||||
return false, fmt.Errorf("go migrations must have a mode set")
|
||||
}
|
||||
if direction {
|
||||
return m.goUp.Mode == TransactionEnabled, nil
|
||||
}
|
||||
return m.goDown.Mode == TransactionEnabled, nil
|
||||
case TypeSQL:
|
||||
if !m.sql.Parsed {
|
||||
return false, fmt.Errorf("sql migrations must be parsed")
|
||||
}
|
||||
return m.sql.UseTx, nil
|
||||
}
|
||||
return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type)
|
||||
}
|
||||
|
||||
// isEmpty is a helper function that returns true if the migration has no functions or no statements
|
||||
// to execute. It must only be called after the migration has been parsed and initialized.
|
||||
func isEmpty(m *Migration, direction bool) bool {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
if direction {
|
||||
return m.goUp.RunTx == nil && m.goUp.RunDB == nil
|
||||
}
|
||||
return m.goDown.RunTx == nil && m.goDown.RunDB == nil
|
||||
case TypeSQL:
|
||||
if direction {
|
||||
return len(m.sql.Up) == 0
|
||||
}
|
||||
return len(m.sql.Down) == 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// runMigration is a helper function that runs the migration in the given direction. It must only be
|
||||
// called after the migration has been parsed and initialized.
|
||||
func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
return runGo(ctx, db, m, direction)
|
||||
case TypeSQL:
|
||||
return runSQL(ctx, db, m, direction)
|
||||
}
|
||||
return fmt.Errorf("invalid migration type: %q", m.Type)
|
||||
}
|
||||
|
||||
// runGo is a helper function that runs the given Go functions in the given direction. It must only
|
||||
// be called after the migration has been initialized.
|
||||
func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||
switch db := db.(type) {
|
||||
case *sql.Conn:
|
||||
return fmt.Errorf("go migrations are not supported with *sql.Conn")
|
||||
case *sql.DB:
|
||||
if direction && m.goUp.RunDB != nil {
|
||||
return m.goUp.RunDB(ctx, db)
|
||||
}
|
||||
if !direction && m.goDown.RunDB != nil {
|
||||
return m.goDown.RunDB(ctx, db)
|
||||
}
|
||||
return nil
|
||||
case *sql.Tx:
|
||||
if direction && m.goUp.RunTx != nil {
|
||||
return m.goUp.RunTx(ctx, db)
|
||||
}
|
||||
if !direction && m.goDown.RunTx != nil {
|
||||
return m.goDown.RunTx(ctx, db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid database connection type: %T", db)
|
||||
}
|
||||
|
||||
// runSQL is a helper function that runs the given SQL statements in the given direction. It must
|
||||
// only be called after the migration has been parsed.
|
||||
func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||
if !m.sql.Parsed {
|
||||
return fmt.Errorf("sql migrations must be parsed")
|
||||
}
|
||||
var statements []string
|
||||
if direction {
|
||||
statements = m.sql.Up
|
||||
} else {
|
||||
statements = m.sql.Down
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package provider_test
|
||||
package goose_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -16,9 +16,9 @@ import (
|
|||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/provider"
|
||||
"github.com/pressly/goose/v3/internal/testdb"
|
||||
"github.com/pressly/goose/v3/lock"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
@ -45,22 +45,22 @@ func TestProviderRun(t *testing.T) {
|
|||
p, _ := newProviderWithDB(t)
|
||||
_, err := p.ApplyVersion(context.Background(), 999, true)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
|
||||
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
||||
_, err = p.ApplyVersion(context.Background(), 999, false)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
|
||||
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
||||
})
|
||||
t.Run("run_zero", func(t *testing.T) {
|
||||
p, _ := newProviderWithDB(t)
|
||||
_, err := p.UpTo(context.Background(), 0)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
||||
check.Equal(t, err.Error(), "version must be greater than 0")
|
||||
_, err = p.DownTo(context.Background(), -1)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
||||
_, err = p.ApplyVersion(context.Background(), 0, true)
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
||||
check.Equal(t, err.Error(), "version must be greater than 0")
|
||||
})
|
||||
t.Run("up_and_down_all", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
@ -72,30 +72,30 @@ func TestProviderRun(t *testing.T) {
|
|||
check.Number(t, len(sources), numCount)
|
||||
// Ensure only SQL migrations are returned
|
||||
for _, s := range sources {
|
||||
check.Equal(t, s.Type, provider.TypeSQL)
|
||||
check.Equal(t, s.Type, goose.TypeSQL)
|
||||
}
|
||||
// Test Up
|
||||
res, err := p.Up(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), numCount)
|
||||
assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
||||
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
|
||||
assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
|
||||
assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
|
||||
assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
||||
assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
||||
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "up", false)
|
||||
assertResult(t, res[4], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "up", false)
|
||||
assertResult(t, res[5], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "up", true)
|
||||
assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
||||
// Test Down
|
||||
res, err = p.DownTo(ctx, 0)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(res), numCount)
|
||||
assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
||||
assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
||||
assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
||||
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
|
||||
assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
|
||||
assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
|
||||
assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
|
||||
assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
||||
assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
||||
assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
||||
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "down", false)
|
||||
assertResult(t, res[4], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "down", false)
|
||||
assertResult(t, res[5], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "down", false)
|
||||
assertResult(t, res[6], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "down", false)
|
||||
})
|
||||
t.Run("up_and_down_by_one", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
@ -107,8 +107,8 @@ func TestProviderRun(t *testing.T) {
|
|||
res, err := p.UpByOne(ctx)
|
||||
counter++
|
||||
if counter > maxVersion {
|
||||
if !errors.Is(err, provider.ErrNoNextVersion) {
|
||||
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
|
||||
if !errors.Is(err, goose.ErrNoNextVersion) {
|
||||
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
@ -126,8 +126,8 @@ func TestProviderRun(t *testing.T) {
|
|||
res, err := p.Down(ctx)
|
||||
counter++
|
||||
if counter > maxVersion {
|
||||
if !errors.Is(err, provider.ErrNoNextVersion) {
|
||||
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
|
||||
if !errors.Is(err, goose.ErrNoNextVersion) {
|
||||
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
@ -149,14 +149,14 @@ func TestProviderRun(t *testing.T) {
|
|||
results, err := p.UpTo(ctx, upToVersion)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(results), upToVersion)
|
||||
assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||
// Fetch the goose version from DB
|
||||
currentVersion, err := p.GetDBVersion(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, currentVersion, upToVersion)
|
||||
// Validate the version actually matches what goose claims it is
|
||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, gotVersion, upToVersion)
|
||||
})
|
||||
|
@ -197,7 +197,7 @@ func TestProviderRun(t *testing.T) {
|
|||
check.NoError(t, err)
|
||||
check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version)
|
||||
// Validate the db migration version actually matches what goose claims it is
|
||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, gotVersion, currentVersion)
|
||||
tables, err := getTableNames(db)
|
||||
|
@ -213,13 +213,13 @@ func TestProviderRun(t *testing.T) {
|
|||
downResult, err := p.DownTo(ctx, 0)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(downResult), len(sources))
|
||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, gotVersion, 0)
|
||||
// Should only be left with a single table, the default goose table
|
||||
tables, err := getTableNames(db)
|
||||
check.NoError(t, err)
|
||||
knownTables := []string{provider.DefaultTablename, "sqlite_sequence"}
|
||||
knownTables := []string{goose.DefaultTablename, "sqlite_sequence"}
|
||||
if !reflect.DeepEqual(tables, knownTables) {
|
||||
t.Logf("got tables: %v", tables)
|
||||
t.Logf("known tables: %v", knownTables)
|
||||
|
@ -261,7 +261,7 @@ func TestProviderRun(t *testing.T) {
|
|||
check.NoError(t, err)
|
||||
_, err = p.ApplyVersion(ctx, 1, true)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true)
|
||||
check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true)
|
||||
check.Contains(t, err.Error(), "version 1: already applied")
|
||||
})
|
||||
t.Run("status", func(t *testing.T) {
|
||||
|
@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) {
|
|||
status, err := p.Status(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(status), numCount)
|
||||
assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true)
|
||||
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true)
|
||||
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true)
|
||||
assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true)
|
||||
assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true)
|
||||
assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true)
|
||||
assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
||||
assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true)
|
||||
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true)
|
||||
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true)
|
||||
assertStatus(t, status[3], goose.StatePending, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), true)
|
||||
assertStatus(t, status[4], goose.StatePending, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), true)
|
||||
assertStatus(t, status[5], goose.StatePending, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), true)
|
||||
assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
||||
// Apply all migrations
|
||||
_, err = p.Up(ctx)
|
||||
check.NoError(t, err)
|
||||
status, err = p.Status(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(status), numCount)
|
||||
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
|
||||
assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false)
|
||||
assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false)
|
||||
assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false)
|
||||
assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false)
|
||||
assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false)
|
||||
assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false)
|
||||
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||
assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false)
|
||||
assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false)
|
||||
assertStatus(t, status[3], goose.StateApplied, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), false)
|
||||
assertStatus(t, status[4], goose.StateApplied, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), false)
|
||||
assertStatus(t, status[5], goose.StateApplied, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), false)
|
||||
assertStatus(t, status[6], goose.StateApplied, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), false)
|
||||
})
|
||||
t.Run("tx_partial_errors", func(t *testing.T) {
|
||||
countOwners := func(db *sql.DB) (int, error) {
|
||||
|
@ -321,22 +321,22 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2');
|
|||
INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||
`),
|
||||
}
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS)
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, mapFS)
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(ctx)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
|
||||
var expected *provider.PartialError
|
||||
var expected *goose.PartialError
|
||||
check.Bool(t, errors.As(err, &expected), true)
|
||||
// Check Err field
|
||||
check.Bool(t, expected.Err != nil, true)
|
||||
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||
// Check Results field
|
||||
check.Number(t, len(expected.Applied), 1)
|
||||
assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||
// Check Failed field
|
||||
check.Bool(t, expected.Failed != nil, true)
|
||||
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
|
||||
assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2)
|
||||
check.Bool(t, expected.Failed.Empty, false)
|
||||
check.Bool(t, expected.Failed.Error != nil, true)
|
||||
check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||
|
@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
|||
status, err := p.Status(ctx)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(status), 3)
|
||||
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
|
||||
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true)
|
||||
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true)
|
||||
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true)
|
||||
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -415,7 +415,7 @@ func TestConcurrentProvider(t *testing.T) {
|
|||
check.NoError(t, err)
|
||||
check.Number(t, currentVersion, maxVersion)
|
||||
|
||||
ch := make(chan []*provider.MigrationResult)
|
||||
ch := make(chan []*goose.MigrationResult)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < maxVersion; i++ {
|
||||
wg.Add(1)
|
||||
|
@ -435,8 +435,8 @@ func TestConcurrentProvider(t *testing.T) {
|
|||
close(ch)
|
||||
}()
|
||||
var (
|
||||
valid [][]*provider.MigrationResult
|
||||
empty [][]*provider.MigrationResult
|
||||
valid [][]*goose.MigrationResult
|
||||
empty [][]*goose.MigrationResult
|
||||
)
|
||||
for results := range ch {
|
||||
if len(results) == 0 {
|
||||
|
@ -486,9 +486,9 @@ func TestNoVersioning(t *testing.T) {
|
|||
// These are owners created by migration files.
|
||||
wantOwnerCount = 4
|
||||
)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
provider.WithVerbose(testing.Verbose()),
|
||||
provider.WithDisabledVersioning(false), // This is the default.
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
goose.WithVerbose(testing.Verbose()),
|
||||
goose.WithDisableVersioning(false), // This is the default.
|
||||
)
|
||||
check.Number(t, len(p.ListSources()), 3)
|
||||
check.NoError(t, err)
|
||||
|
@ -499,9 +499,9 @@ func TestNoVersioning(t *testing.T) {
|
|||
check.Number(t, baseVersion, 3)
|
||||
t.Run("seed-up-down-to-zero", func(t *testing.T) {
|
||||
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
provider.WithVerbose(testing.Verbose()),
|
||||
provider.WithDisabledVersioning(true), // Provider with no versioning.
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||
goose.WithVerbose(testing.Verbose()),
|
||||
goose.WithDisableVersioning(true), // Provider with no versioning.
|
||||
)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(p.ListSources()), 2)
|
||||
|
@ -552,8 +552,8 @@ func TestAllowMissing(t *testing.T) {
|
|||
|
||||
t.Run("missing_now_allowed", func(t *testing.T) {
|
||||
db := newDB(t)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||
provider.WithAllowedMissing(false),
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||
goose.WithAllowOutofOrder(false),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
|
||||
|
@ -607,8 +607,8 @@ func TestAllowMissing(t *testing.T) {
|
|||
|
||||
t.Run("missing_allowed", func(t *testing.T) {
|
||||
db := newDB(t)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||
provider.WithAllowedMissing(true),
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||
goose.WithAllowOutofOrder(true),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
|
||||
|
@ -640,7 +640,7 @@ func TestAllowMissing(t *testing.T) {
|
|||
check.Bool(t, upResult != nil, true)
|
||||
check.Number(t, upResult.Source.Version, 6)
|
||||
|
||||
count, err := getGooseVersionCount(db, provider.DefaultTablename)
|
||||
count, err := getGooseVersionCount(db, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, count, 6)
|
||||
current, err := p.GetDBVersion(ctx)
|
||||
|
@ -676,7 +676,7 @@ func TestAllowMissing(t *testing.T) {
|
|||
testDownAndVersion(1, 1)
|
||||
_, err = p.Down(ctx)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
|
||||
check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -691,6 +691,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
|
|||
}
|
||||
|
||||
func TestGoOnly(t *testing.T) {
|
||||
t.Cleanup(goose.ResetGlobalMigrations)
|
||||
// Not parallel because each subtest modifies global state.
|
||||
|
||||
countUser := func(db *sql.DB) int {
|
||||
|
@ -703,99 +704,109 @@ func TestGoOnly(t *testing.T) {
|
|||
|
||||
t.Run("with_tx", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
register := []*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||
DownFnContext: newTxFn("DROP TABLE users"),
|
||||
},
|
||||
register := []*goose.Migration{
|
||||
goose.NewGoMigration(
|
||||
1,
|
||||
&goose.GoFunc{RunTx: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)")},
|
||||
&goose.GoFunc{RunTx: newTxFn("DROP TABLE users")},
|
||||
),
|
||||
}
|
||||
err := provider.SetGlobalGoMigrations(register)
|
||||
err := goose.SetGlobalMigrations(register...)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
t.Cleanup(goose.ResetGlobalMigrations)
|
||||
|
||||
db := newDB(t)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
||||
provider.WithGoMigration(
|
||||
register = []*goose.Migration{
|
||||
goose.NewGoMigration(
|
||||
2,
|
||||
&provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||
&provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")},
|
||||
&goose.GoFunc{RunTx: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||
&goose.GoFunc{RunTx: newTxFn("DELETE FROM users")},
|
||||
),
|
||||
}
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||
goose.WithGoMigrations(register...),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
sources := p.ListSources()
|
||||
check.Number(t, len(p.ListSources()), 2)
|
||||
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
|
||||
assertSource(t, sources[1], provider.TypeGo, "", 2)
|
||||
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||
// Apply migration 1
|
||||
res, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
check.Bool(t, tableExists(t, db, "users"), true)
|
||||
// Apply migration 2
|
||||
res, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||
check.Number(t, countUser(db), 3)
|
||||
// Rollback migration 2
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
// Rollback migration 1
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||
// Check table does not exist
|
||||
check.Bool(t, tableExists(t, db, "users"), false)
|
||||
})
|
||||
t.Run("with_db", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
register := []*provider.MigrationCopy{
|
||||
{
|
||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
||||
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||
DownFnNoTxContext: newDBFn("DROP TABLE users"),
|
||||
},
|
||||
register := []*goose.Migration{
|
||||
goose.NewGoMigration(
|
||||
1,
|
||||
&goose.GoFunc{
|
||||
RunDB: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||
},
|
||||
&goose.GoFunc{
|
||||
RunDB: newDBFn("DROP TABLE users"),
|
||||
},
|
||||
),
|
||||
}
|
||||
err := provider.SetGlobalGoMigrations(register)
|
||||
err := goose.SetGlobalMigrations(register...)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
||||
t.Cleanup(goose.ResetGlobalMigrations)
|
||||
|
||||
db := newDB(t)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
||||
provider.WithGoMigration(
|
||||
register = []*goose.Migration{
|
||||
goose.NewGoMigration(
|
||||
2,
|
||||
&provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||
&provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")},
|
||||
&goose.GoFunc{RunDB: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||
&goose.GoFunc{RunDB: newDBFn("DELETE FROM users")},
|
||||
),
|
||||
}
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||
goose.WithGoMigrations(register...),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
sources := p.ListSources()
|
||||
check.Number(t, len(p.ListSources()), 2)
|
||||
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
|
||||
assertSource(t, sources[1], provider.TypeGo, "", 2)
|
||||
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||
// Apply migration 1
|
||||
res, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
check.Bool(t, tableExists(t, db, "users"), true)
|
||||
// Apply migration 2
|
||||
res, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||
check.Number(t, countUser(db), 3)
|
||||
// Rollback migration 2
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||
check.Number(t, countUser(db), 0)
|
||||
// Rollback migration 1
|
||||
res, err = p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||
// Check table does not exist
|
||||
check.Bool(t, tableExists(t, db, "users"), false)
|
||||
})
|
||||
|
@ -818,16 +829,23 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
|||
check.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
newProvider := func() *provider.Provider {
|
||||
sessionLocker, err := lock.NewPostgresSessionLocker()
|
||||
check.NoError(t, err)
|
||||
p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"),
|
||||
provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
|
||||
provider.WithVerbose(testing.Verbose()),
|
||||
newProvider := func() *goose.Provider {
|
||||
|
||||
sessionLocker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockTimeout(5, 60), // Timeout 5min. Try every 5s up to 60 times.
|
||||
)
|
||||
check.NoError(t, err)
|
||||
p, err := goose.NewProvider(
|
||||
database.DialectPostgres,
|
||||
db,
|
||||
os.DirFS("testdata/migrations"),
|
||||
goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
|
||||
)
|
||||
check.NoError(t, err)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
provider1 := newProvider()
|
||||
provider2 := newProvider()
|
||||
|
||||
|
@ -891,7 +909,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
|||
for {
|
||||
result, err := provider1.UpByOne(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
@ -907,7 +925,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
|||
for {
|
||||
result, err := provider2.UpByOne(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
@ -993,7 +1011,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
|||
for {
|
||||
result, err := provider1.Down(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
@ -1009,7 +1027,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
|||
for {
|
||||
result, err := provider2.Down(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
||||
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
@ -1068,14 +1086,14 @@ func randomAlphaNumeric(length int) string {
|
|||
return string(b)
|
||||
}
|
||||
|
||||
func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) {
|
||||
func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provider, *sql.DB) {
|
||||
t.Helper()
|
||||
db := newDB(t)
|
||||
opts = append(
|
||||
opts,
|
||||
provider.WithVerbose(testing.Verbose()),
|
||||
goose.WithVerbose(testing.Verbose()),
|
||||
)
|
||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
|
||||
check.NoError(t, err)
|
||||
return p, db
|
||||
}
|
||||
|
@ -1118,14 +1136,14 @@ func getTableNames(db *sql.DB) ([]string, error) {
|
|||
return tables, nil
|
||||
}
|
||||
|
||||
func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) {
|
||||
func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.State, state)
|
||||
check.Equal(t, got.Source, source)
|
||||
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
|
||||
}
|
||||
|
||||
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
|
||||
func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) {
|
||||
t.Helper()
|
||||
check.Bool(t, got != nil, true)
|
||||
check.Equal(t, got.Source, source)
|
||||
|
@ -1135,21 +1153,15 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
|
|||
check.Bool(t, got.Duration > 0, true)
|
||||
}
|
||||
|
||||
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
|
||||
func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) {
|
||||
t.Helper()
|
||||
check.Equal(t, got.Type, typ)
|
||||
check.Equal(t, got.Path, name)
|
||||
check.Equal(t, got.Version, version)
|
||||
switch got.Type {
|
||||
case provider.TypeGo:
|
||||
check.Equal(t, got.Type.String(), "go")
|
||||
case provider.TypeSQL:
|
||||
check.Equal(t, got.Type.String(), "sql")
|
||||
}
|
||||
}
|
||||
|
||||
func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source {
|
||||
return provider.Source{
|
||||
func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source {
|
||||
return &goose.Source{
|
||||
Type: t,
|
||||
Path: fullpath,
|
||||
Version: version,
|
|
@ -0,0 +1,79 @@
|
|||
package goose_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||
check.NoError(t, err)
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
_, err := goose.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true)
|
||||
})
|
||||
|
||||
mapFS := fstest.MapFS{
|
||||
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
|
||||
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
||||
}
|
||||
fsys, err := fs.Sub(mapFS, "migrations")
|
||||
check.NoError(t, err)
|
||||
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys)
|
||||
check.NoError(t, err)
|
||||
sources := p.ListSources()
|
||||
check.Equal(t, len(sources), 2)
|
||||
check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1))
|
||||
check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2))
|
||||
|
||||
}
|
||||
|
||||
var (
|
||||
migration1 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
||||
-- +goose Down
|
||||
DROP TABLE foo;
|
||||
`
|
||||
migration2 = `
|
||||
-- +goose Up
|
||||
ALTER TABLE foo ADD COLUMN name TEXT;
|
||||
-- +goose Down
|
||||
ALTER TABLE foo DROP COLUMN name;
|
||||
`
|
||||
migration3 = `
|
||||
-- +goose Up
|
||||
CREATE TABLE bar (
|
||||
id INTEGER PRIMARY KEY,
|
||||
description TEXT
|
||||
);
|
||||
-- +goose Down
|
||||
DROP TABLE bar;
|
||||
`
|
||||
migration4 = `
|
||||
-- +goose Up
|
||||
-- Rename the 'foo' table to 'my_foo'
|
||||
ALTER TABLE foo RENAME TO my_foo;
|
||||
|
||||
-- Add a new column 'timestamp' to 'my_foo'
|
||||
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
||||
|
||||
-- +goose Down
|
||||
-- Remove the 'timestamp' column from 'my_foo'
|
||||
ALTER TABLE my_foo DROP COLUMN timestamp;
|
||||
|
||||
-- Rename the 'my_foo' table back to 'foo'
|
||||
ALTER TABLE my_foo RENAME TO foo;
|
||||
`
|
||||
)
|
|
@ -1,30 +1,15 @@
|
|||
package provider
|
||||
package goose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
import "time"
|
||||
|
||||
// MigrationType is the type of migration.
|
||||
type MigrationType int
|
||||
type MigrationType string
|
||||
|
||||
const (
|
||||
TypeGo MigrationType = iota + 1
|
||||
TypeSQL
|
||||
TypeGo MigrationType = "go"
|
||||
TypeSQL MigrationType = "sql"
|
||||
)
|
||||
|
||||
func (t MigrationType) String() string {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
return "go"
|
||||
case TypeSQL:
|
||||
return "sql"
|
||||
default:
|
||||
// This should never happen.
|
||||
return fmt.Sprintf("unknown (%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
// Source represents a single migration source.
|
||||
//
|
||||
// The Path field may be empty if the migration was registered manually. This is typically the case
|
||||
|
@ -37,7 +22,7 @@ type Source struct {
|
|||
|
||||
// MigrationResult is the result of a single migration operation.
|
||||
type MigrationResult struct {
|
||||
Source Source
|
||||
Source *Source
|
||||
Duration time.Duration
|
||||
Direction string
|
||||
// Empty indicates no action was taken during the migration, but it was still versioned. For
|
||||
|
@ -64,7 +49,7 @@ const (
|
|||
|
||||
// MigrationStatus represents the status of a single migration.
|
||||
type MigrationStatus struct {
|
||||
Source Source
|
||||
Source *Source
|
||||
State State
|
||||
AppliedAt time.Time
|
||||
}
|
|
@ -66,7 +66,7 @@ func register(filename string, useTx bool, up, down *GoFunc) error {
|
|||
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
|
||||
// we know based on the register function what the user is requesting.
|
||||
m.UseTx = useTx
|
||||
registeredGoMigrations[v] = &m
|
||||
registeredGoMigrations[v] = m
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
17
types.go
17
types.go
|
@ -1,17 +0,0 @@
|
|||
package goose
|
||||
|
||||
// MigrationType is the type of migration.
|
||||
type MigrationType string
|
||||
|
||||
const (
|
||||
TypeGo MigrationType = "go"
|
||||
TypeSQL MigrationType = "sql"
|
||||
)
|
||||
|
||||
func (t MigrationType) String() string {
|
||||
// This should never happen.
|
||||
if t == "" {
|
||||
return "unknown migration type"
|
||||
}
|
||||
return string(t)
|
||||
}
|
Loading…
Reference in New Issue