mirror of https://github.com/pressly/goose.git
feat: Add goose provider (#635)
parent
8503d4e20b
commit
04e12b88f4
|
@ -2,6 +2,7 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -100,6 +101,9 @@ func (s *store) GetMigration(
|
||||||
&result.Timestamp,
|
&result.Timestamp,
|
||||||
&result.IsApplied,
|
&result.IsApplied,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
|
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
|
||||||
}
|
}
|
||||||
return &result, nil
|
return &result, nil
|
||||||
|
|
|
@ -2,9 +2,15 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
|
||||||
|
ErrVersionNotFound = errors.New("version not found")
|
||||||
|
)
|
||||||
|
|
||||||
// Store is an interface that defines methods for managing database migrations and versioning. By
|
// Store is an interface that defines methods for managing database migrations and versioning. By
|
||||||
// defining a Store interface, we can support multiple databases with consistent functionality.
|
// defining a Store interface, we can support multiple databases with consistent functionality.
|
||||||
//
|
//
|
||||||
|
@ -24,8 +30,8 @@ type Store interface {
|
||||||
// Delete deletes a version id from the version table.
|
// Delete deletes a version id from the version table.
|
||||||
Delete(ctx context.Context, db DBTxConn, version int64) error
|
Delete(ctx context.Context, db DBTxConn, version int64) error
|
||||||
|
|
||||||
// GetMigration retrieves a single migration by version id. This method may return the raw sql
|
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
|
||||||
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
|
// version is not found, this method must return [ErrVersionNotFound].
|
||||||
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
|
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
|
||||||
|
|
||||||
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
|
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
|
||||||
|
|
|
@ -205,7 +205,7 @@ func testStore(
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
_, err := store.GetMigration(ctx, conn, 0)
|
_, err := store.GetMigration(ctx, conn, 0)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
|
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
|
48
globals.go
48
globals.go
|
@ -22,13 +22,12 @@ func ResetGlobalMigrations() {
|
||||||
// [NewGoMigration] function.
|
// [NewGoMigration] function.
|
||||||
//
|
//
|
||||||
// Not safe for concurrent use.
|
// Not safe for concurrent use.
|
||||||
func SetGlobalMigrations(migrations ...Migration) error {
|
func SetGlobalMigrations(migrations ...*Migration) error {
|
||||||
for _, migration := range migrations {
|
for _, m := range migrations {
|
||||||
m := &migration
|
|
||||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
if _, ok := registeredGoMigrations[m.Version]; ok {
|
||||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||||
}
|
}
|
||||||
if err := checkMigration(m); err != nil {
|
if err := checkGoMigration(m); err != nil {
|
||||||
return fmt.Errorf("invalid go migration: %w", err)
|
return fmt.Errorf("invalid go migration: %w", err)
|
||||||
}
|
}
|
||||||
registeredGoMigrations[m.Version] = m
|
registeredGoMigrations[m.Version] = m
|
||||||
|
@ -36,7 +35,7 @@ func SetGlobalMigrations(migrations ...Migration) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkMigration(m *Migration) error {
|
func checkGoMigration(m *Migration) error {
|
||||||
if !m.construct {
|
if !m.construct {
|
||||||
return errors.New("must use NewGoMigration to construct migrations")
|
return errors.New("must use NewGoMigration to construct migrations")
|
||||||
}
|
}
|
||||||
|
@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
|
||||||
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
|
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := setGoFunc(m.goUp); err != nil {
|
if err := checkGoFunc(m.goUp); err != nil {
|
||||||
return fmt.Errorf("up function: %w", err)
|
return fmt.Errorf("up function: %w", err)
|
||||||
}
|
}
|
||||||
if err := setGoFunc(m.goDown); err != nil {
|
if err := checkGoFunc(m.goDown); err != nil {
|
||||||
return fmt.Errorf("down function: %w", err)
|
return fmt.Errorf("down function: %w", err)
|
||||||
}
|
}
|
||||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||||
|
@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setGoFunc(f *GoFunc) error {
|
func checkGoFunc(f *GoFunc) error {
|
||||||
if f == nil {
|
|
||||||
f = &GoFunc{Mode: TransactionEnabled}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if f.RunTx != nil && f.RunDB != nil {
|
if f.RunTx != nil && f.RunDB != nil {
|
||||||
return errors.New("must specify exactly one of RunTx or RunDB")
|
return errors.New("must specify exactly one of RunTx or RunDB")
|
||||||
}
|
}
|
||||||
if f.RunTx == nil && f.RunDB == nil {
|
|
||||||
switch f.Mode {
|
switch f.Mode {
|
||||||
case 0:
|
|
||||||
// Default to TransactionEnabled ONLY if mode is not set explicitly.
|
|
||||||
f.Mode = TransactionEnabled
|
|
||||||
case TransactionEnabled, TransactionDisabled:
|
case TransactionEnabled, TransactionDisabled:
|
||||||
// No functions but mode is set. This is not an error. It means the user wants to record
|
// No functions, but mode is set. This is not an error. It means the user wants to
|
||||||
// a version with the given mode but not run any functions.
|
// record a version with the given mode but not run any functions.
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid mode: %d", f.Mode)
|
return fmt.Errorf("invalid mode: %d", f.Mode)
|
||||||
}
|
}
|
||||||
return nil
|
if f.RunDB != nil && f.Mode != TransactionDisabled {
|
||||||
}
|
|
||||||
if f.RunDB != nil {
|
|
||||||
switch f.Mode {
|
|
||||||
case 0, TransactionDisabled:
|
|
||||||
f.Mode = TransactionDisabled
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
|
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
|
||||||
}
|
}
|
||||||
}
|
if f.RunTx != nil && f.Mode != TransactionEnabled {
|
||||||
if f.RunTx != nil {
|
|
||||||
switch f.Mode {
|
|
||||||
case 0, TransactionEnabled:
|
|
||||||
f.Mode = TransactionEnabled
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
|
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
|
|
||||||
// the functions or return an error. This should never happen.
|
|
||||||
if f.Mode == 0 {
|
|
||||||
return errors.New("failed to infer transaction mode")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
|
||||||
// reset so we can check the default is set
|
// reset so we can check the default is set
|
||||||
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
||||||
err = SetGlobalMigrations(migration2)
|
err = SetGlobalMigrations(migration2)
|
||||||
check.NoError(t, err)
|
check.HasError(t, err)
|
||||||
check.Number(t, len(registeredGoMigrations), 2)
|
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
|
||||||
registered = registeredGoMigrations[2]
|
|
||||||
check.Bool(t, registered.goUp != nil, true)
|
migration3 := NewGoMigration(3, nil, nil)
|
||||||
check.Bool(t, registered.goDown != nil, true)
|
// reset so we can check the default is set
|
||||||
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
migration3.goDown.Mode = 0
|
||||||
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
err = SetGlobalMigrations(migration3)
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
|
||||||
})
|
})
|
||||||
t.Run("unknown_mode", func(t *testing.T) {
|
t.Run("unknown_mode", func(t *testing.T) {
|
||||||
m := NewGoMigration(1, nil, nil)
|
m := NewGoMigration(1, nil, nil)
|
||||||
|
@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
|
||||||
runTx := func(context.Context, *sql.Tx) error { return nil }
|
runTx := func(context.Context, *sql.Tx) error { return nil }
|
||||||
|
|
||||||
// Success.
|
// Success.
|
||||||
err := SetGlobalMigrations([]Migration{}...)
|
err := SetGlobalMigrations([]*Migration{}...)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||||
|
@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "go migration with version 1 already registered")
|
check.Contains(t, err.Error(), "go migration with version 1 already registered")
|
||||||
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
|
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckMigration(t *testing.T) {
|
func TestCheckMigration(t *testing.T) {
|
||||||
// Failures.
|
|
||||||
err := checkMigration(&Migration{})
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
|
||||||
err = checkMigration(&Migration{construct: true})
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "must be registered")
|
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true})
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), `type must be "go"`)
|
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "version must be greater than zero")
|
|
||||||
// Success.
|
// Success.
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
|
err := checkGoMigration(NewGoMigration(1, nil, nil))
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
// Failures.
|
// Failures.
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
err = checkGoMigration(&Migration{})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||||
|
err = checkGoMigration(&Migration{construct: true})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "must be registered")
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), `type must be "go"`)
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "version must be greater than zero")
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "up function: invalid mode: 0")
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Contains(t, err.Error(), "down function: invalid mode: 0")
|
||||||
|
// Success.
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
|
||||||
|
check.NoError(t, err)
|
||||||
|
// Failures.
|
||||||
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), `no filename separator '_' found`)
|
check.Contains(t, err.Error(), `no filename separator '_' found`)
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||||
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||||
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||||
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||||
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
UpFn: func(*sql.Tx) error { return nil },
|
UpFn: func(*sql.Tx) error { return nil },
|
||||||
UpFnNoTx: func(*sql.DB) error { return nil },
|
UpFnNoTx: func(*sql.DB) error { return nil },
|
||||||
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
||||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
DownFn: func(*sql.Tx) error { return nil },
|
DownFn: func(*sql.Tx) error { return nil },
|
||||||
DownFnNoTx: func(*sql.DB) error { return nil },
|
DownFnNoTx: func(*sql.DB) error { return nil },
|
||||||
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
||||||
|
|
|
@ -1,186 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
type migration struct {
|
|
||||||
Source Source
|
|
||||||
// A migration is either a Go migration or a SQL migration, but never both.
|
|
||||||
//
|
|
||||||
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
|
|
||||||
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
|
|
||||||
// majority of the time migrations are incremental, so it is likely that the user will only want
|
|
||||||
// to run the last few migrations, and there is no need to parse ALL prior migrations.
|
|
||||||
//
|
|
||||||
// Exactly one of these fields will be set:
|
|
||||||
Go *goMigration
|
|
||||||
// -- OR --
|
|
||||||
SQL *sqlMigration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *migration) useTx(direction bool) bool {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeSQL:
|
|
||||||
return m.SQL.UseTx
|
|
||||||
case TypeGo:
|
|
||||||
if m.Go == nil || m.Go.isEmpty(direction) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if direction {
|
|
||||||
return m.Go.up.Run != nil
|
|
||||||
}
|
|
||||||
return m.Go.down.Run != nil
|
|
||||||
}
|
|
||||||
// This should never happen.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *migration) isEmpty(direction bool) bool {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeSQL:
|
|
||||||
return m.SQL == nil || m.SQL.isEmpty(direction)
|
|
||||||
case TypeGo:
|
|
||||||
return m.Go == nil || m.Go.isEmpty(direction)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *migration) filename() string {
|
|
||||||
return filepath.Base(m.Source.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// run runs the migration inside of a transaction.
|
|
||||||
func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeSQL:
|
|
||||||
if m.SQL == nil {
|
|
||||||
return fmt.Errorf("tx: sql migration has not been parsed")
|
|
||||||
}
|
|
||||||
return m.SQL.run(ctx, tx, direction)
|
|
||||||
case TypeGo:
|
|
||||||
return m.Go.run(ctx, tx, direction)
|
|
||||||
}
|
|
||||||
// This should never happen.
|
|
||||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
|
||||||
}
|
|
||||||
|
|
||||||
// runNoTx runs the migration without a transaction.
|
|
||||||
func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeSQL:
|
|
||||||
if m.SQL == nil {
|
|
||||||
return fmt.Errorf("db: sql migration has not been parsed")
|
|
||||||
}
|
|
||||||
return m.SQL.run(ctx, db, direction)
|
|
||||||
case TypeGo:
|
|
||||||
return m.Go.runNoTx(ctx, db, direction)
|
|
||||||
}
|
|
||||||
// This should never happen.
|
|
||||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
|
||||||
}
|
|
||||||
|
|
||||||
// runConn runs the migration without a transaction using the provided connection.
|
|
||||||
func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeSQL:
|
|
||||||
if m.SQL == nil {
|
|
||||||
return fmt.Errorf("conn: sql migration has not been parsed")
|
|
||||||
}
|
|
||||||
return m.SQL.run(ctx, conn, direction)
|
|
||||||
case TypeGo:
|
|
||||||
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
|
|
||||||
}
|
|
||||||
// This should never happen.
|
|
||||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Path))
|
|
||||||
}
|
|
||||||
|
|
||||||
type goMigration struct {
|
|
||||||
fullpath string
|
|
||||||
up, down *GoMigrationFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *goMigration) isEmpty(direction bool) bool {
|
|
||||||
if g.up == nil && g.down == nil {
|
|
||||||
panic("go migration has no up or down")
|
|
||||||
}
|
|
||||||
if direction {
|
|
||||||
return g.up == nil
|
|
||||||
}
|
|
||||||
return g.down == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration {
|
|
||||||
return &goMigration{
|
|
||||||
fullpath: fullpath,
|
|
||||||
up: up,
|
|
||||||
down: down,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
|
||||||
if g == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var fn func(context.Context, *sql.Tx) error
|
|
||||||
if direction && g.up != nil {
|
|
||||||
fn = g.up.Run
|
|
||||||
}
|
|
||||||
if !direction && g.down != nil {
|
|
||||||
fn = g.down.Run
|
|
||||||
}
|
|
||||||
if fn != nil {
|
|
||||||
return fn(ctx, tx)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
|
||||||
if g == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var fn func(context.Context, *sql.DB) error
|
|
||||||
if direction && g.up != nil {
|
|
||||||
fn = g.up.RunNoTx
|
|
||||||
}
|
|
||||||
if !direction && g.down != nil {
|
|
||||||
fn = g.down.RunNoTx
|
|
||||||
}
|
|
||||||
if fn != nil {
|
|
||||||
return fn(ctx, db)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type sqlMigration struct {
|
|
||||||
UseTx bool
|
|
||||||
UpStatements []string
|
|
||||||
DownStatements []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlMigration) isEmpty(direction bool) bool {
|
|
||||||
if direction {
|
|
||||||
return len(s.UpStatements) == 0
|
|
||||||
}
|
|
||||||
return len(s.DownStatements) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
|
|
||||||
var statements []string
|
|
||||||
if direction {
|
|
||||||
statements = s.UpStatements
|
|
||||||
} else {
|
|
||||||
statements = s.DownStatements
|
|
||||||
}
|
|
||||||
for _, stmt := range statements {
|
|
||||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MigrationCopy struct {
|
|
||||||
Version int64
|
|
||||||
Source string // path to .sql script or go file
|
|
||||||
Registered bool
|
|
||||||
UpFnContext, DownFnContext func(context.Context, *sql.Tx) error
|
|
||||||
UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
var registeredGoMigrations = make(map[int64]*MigrationCopy)
|
|
||||||
|
|
||||||
// SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of
|
|
||||||
// the migrations are nil or if a migration with the same version has already been registered.
|
|
||||||
//
|
|
||||||
// Not safe for concurrent use.
|
|
||||||
func SetGlobalGoMigrations(migrations []*MigrationCopy) error {
|
|
||||||
for _, m := range migrations {
|
|
||||||
if m == nil {
|
|
||||||
return errors.New("cannot register nil go migration")
|
|
||||||
}
|
|
||||||
if m.Version < 1 {
|
|
||||||
return errors.New("migration versions must be greater than zero")
|
|
||||||
}
|
|
||||||
if !m.Registered {
|
|
||||||
return errors.New("migration must be registered")
|
|
||||||
}
|
|
||||||
if m.Source != "" {
|
|
||||||
// If the source is set, expect it to be a file path with a numeric component that
|
|
||||||
// matches the version.
|
|
||||||
version, err := goose.NumericComponent(m.Source)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if version != m.Version {
|
|
||||||
return fmt.Errorf("migration version %d does not match source %q", m.Version, m.Source)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// It's valid for all of these to be nil.
|
|
||||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
|
||||||
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
|
|
||||||
}
|
|
||||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
|
||||||
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
|
|
||||||
}
|
|
||||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
|
||||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
|
||||||
}
|
|
||||||
registeredGoMigrations[m.Version] = m
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetGlobalGoMigrations resets the global go migrations registry.
|
|
||||||
//
|
|
||||||
// Not safe for concurrent use.
|
|
||||||
func ResetGlobalGoMigrations() {
|
|
||||||
registeredGoMigrations = make(map[int64]*MigrationCopy)
|
|
||||||
}
|
|
|
@ -1,272 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Provider is a goose migration provider.
|
|
||||||
type Provider struct {
|
|
||||||
// mu protects all accesses to the provider and must be held when calling operations on the
|
|
||||||
// database.
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
db *sql.DB
|
|
||||||
fsys fs.FS
|
|
||||||
cfg config
|
|
||||||
store database.Store
|
|
||||||
|
|
||||||
// migrations are ordered by version in ascending order.
|
|
||||||
migrations []*migration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProvider returns a new goose Provider.
|
|
||||||
//
|
|
||||||
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
|
||||||
// example, if the database dialect is "postgres", the database/sql driver could be
|
|
||||||
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
|
|
||||||
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
|
|
||||||
// as using a custom table name or supplying a custom store implementation, see [WithStore].
|
|
||||||
//
|
|
||||||
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
|
|
||||||
// use [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
|
|
||||||
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
|
|
||||||
// migrations using [fs.Sub].
|
|
||||||
//
|
|
||||||
// See [ProviderOption] for more information on configuring the provider.
|
|
||||||
//
|
|
||||||
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
|
|
||||||
//
|
|
||||||
// Experimental: This API is experimental and may change in the future.
|
|
||||||
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
|
|
||||||
if db == nil {
|
|
||||||
return nil, errors.New("db must not be nil")
|
|
||||||
}
|
|
||||||
if fsys == nil {
|
|
||||||
fsys = noopFS{}
|
|
||||||
}
|
|
||||||
cfg := config{
|
|
||||||
registered: make(map[int64]*goMigration),
|
|
||||||
excludes: make(map[string]bool),
|
|
||||||
}
|
|
||||||
for _, opt := range opts {
|
|
||||||
if err := opt.apply(&cfg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Allow users to specify a custom store implementation, but only if they don't specify a
|
|
||||||
// dialect. If they specify a dialect, we'll use the default store implementation.
|
|
||||||
if dialect == "" && cfg.store == nil {
|
|
||||||
return nil, errors.New("dialect must not be empty")
|
|
||||||
}
|
|
||||||
if dialect != "" && cfg.store != nil {
|
|
||||||
return nil, errors.New("cannot set both dialect and custom store")
|
|
||||||
}
|
|
||||||
var store database.Store
|
|
||||||
if dialect != "" {
|
|
||||||
var err error
|
|
||||||
store, err = database.NewStore(dialect, DefaultTablename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
store = cfg.store
|
|
||||||
}
|
|
||||||
if store.Tablename() == "" {
|
|
||||||
return nil, errors.New("invalid store implementation: table name must not be empty")
|
|
||||||
}
|
|
||||||
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newProvider(
|
|
||||||
db *sql.DB,
|
|
||||||
store database.Store,
|
|
||||||
fsys fs.FS,
|
|
||||||
cfg config,
|
|
||||||
global map[int64]*MigrationCopy,
|
|
||||||
) (*Provider, error) {
|
|
||||||
// Collect migrations from the filesystem and merge with registered migrations.
|
|
||||||
//
|
|
||||||
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
|
|
||||||
// lazily.
|
|
||||||
//
|
|
||||||
// TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to
|
|
||||||
// return an error if there are any SQL parsing errors. This adds a bit overhead to startup
|
|
||||||
// though, so we should make it optional.
|
|
||||||
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
registered := make(map[int64]*goMigration)
|
|
||||||
// Add user-registered Go migrations.
|
|
||||||
for version, m := range cfg.registered {
|
|
||||||
registered[version] = newGoMigration("", m.up, m.down)
|
|
||||||
}
|
|
||||||
// Add init() functions. This is a bit ugly because we need to convert from the old Migration
|
|
||||||
// struct to the new goMigration struct.
|
|
||||||
for version, m := range global {
|
|
||||||
if _, ok := registered[version]; ok {
|
|
||||||
return nil, fmt.Errorf("go migration with version %d already registered", version)
|
|
||||||
}
|
|
||||||
if m == nil {
|
|
||||||
return nil, errors.New("registered migration with nil init function")
|
|
||||||
}
|
|
||||||
g := newGoMigration(m.Source, nil, nil)
|
|
||||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
|
||||||
return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext")
|
|
||||||
}
|
|
||||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
|
||||||
return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext")
|
|
||||||
}
|
|
||||||
// Up
|
|
||||||
if m.UpFnContext != nil {
|
|
||||||
g.up = &GoMigrationFunc{
|
|
||||||
Run: m.UpFnContext,
|
|
||||||
}
|
|
||||||
} else if m.UpFnNoTxContext != nil {
|
|
||||||
g.up = &GoMigrationFunc{
|
|
||||||
RunNoTx: m.UpFnNoTxContext,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Down
|
|
||||||
if m.DownFnContext != nil {
|
|
||||||
g.down = &GoMigrationFunc{
|
|
||||||
Run: m.DownFnContext,
|
|
||||||
}
|
|
||||||
} else if m.DownFnNoTxContext != nil {
|
|
||||||
g.down = &GoMigrationFunc{
|
|
||||||
RunNoTx: m.DownFnNoTxContext,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
registered[version] = g
|
|
||||||
}
|
|
||||||
migrations, err := merge(filesystemSources, registered)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(migrations) == 0 {
|
|
||||||
return nil, ErrNoMigrations
|
|
||||||
}
|
|
||||||
return &Provider{
|
|
||||||
db: db,
|
|
||||||
fsys: fsys,
|
|
||||||
cfg: cfg,
|
|
||||||
store: store,
|
|
||||||
migrations: migrations,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status returns the status of all migrations, merging the list of migrations from the database and
|
|
||||||
// filesystem. The returned items are ordered by version, in ascending order.
|
|
||||||
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
|
||||||
return p.status(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDBVersion returns the max version from the database, regardless of the applied order. For
|
|
||||||
// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been
|
|
||||||
// applied, it returns 0.
|
|
||||||
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
|
||||||
return p.getDBVersion(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListSources returns a list of all available migration sources the provider is aware of, sorted in
|
|
||||||
// ascending order by version.
|
|
||||||
func (p *Provider) ListSources() []Source {
|
|
||||||
sources := make([]Source, 0, len(p.migrations))
|
|
||||||
for _, m := range p.migrations {
|
|
||||||
sources = append(sources, m.Source)
|
|
||||||
}
|
|
||||||
return sources
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ping attempts to ping the database to verify a connection is available.
|
|
||||||
func (p *Provider) Ping(ctx context.Context) error {
|
|
||||||
return p.db.PingContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the database connection.
|
|
||||||
func (p *Provider) Close() error {
|
|
||||||
return p.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyVersion applies exactly one migration by version. If there is no source for the specified
|
|
||||||
// version, this method returns [ErrVersionNotFound]. If the migration has been applied already,
|
|
||||||
// this method returns [ErrAlreadyApplied].
|
|
||||||
//
|
|
||||||
// When direction is true, the up migration is executed, and when direction is false, the down
|
|
||||||
// migration is executed.
|
|
||||||
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
|
||||||
if version < 1 {
|
|
||||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
|
||||||
}
|
|
||||||
return p.apply(ctx, version, direction)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Up applies all [StatePending] migrations. If there are no new migrations to apply, this method
|
|
||||||
// returns empty list and nil error.
|
|
||||||
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
|
||||||
return p.up(ctx, false, math.MaxInt64)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpByOne applies the next available migration. If there are no migrations to apply, this method
|
|
||||||
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
|
|
||||||
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
|
||||||
res, err := p.up(ctx, true, math.MaxInt64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res) == 0 {
|
|
||||||
return nil, ErrNoNextVersion
|
|
||||||
}
|
|
||||||
// This should never happen. We should always have exactly one result and test for this.
|
|
||||||
if len(res) > 1 {
|
|
||||||
return nil, fmt.Errorf("unexpected number of migrations returned running up-by-one: %d", len(res))
|
|
||||||
}
|
|
||||||
return res[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpTo applies all available migrations up to, and including, the specified version. If there are
|
|
||||||
// no migrations to apply, this method returns empty list and nil error.
|
|
||||||
//
|
|
||||||
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
|
|
||||||
// with a requested version of 10, only versions 9,10 will be applied.
|
|
||||||
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
|
||||||
if version < 1 {
|
|
||||||
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
|
|
||||||
}
|
|
||||||
return p.up(ctx, false, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Down rolls back the most recently applied migration. If there are no migrations to apply, this
|
|
||||||
// method returns [ErrNoNextVersion].
|
|
||||||
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
|
||||||
res, err := p.down(ctx, true, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res) == 0 {
|
|
||||||
return nil, ErrNoNextVersion
|
|
||||||
}
|
|
||||||
if len(res) > 1 {
|
|
||||||
return nil, fmt.Errorf("unexpected number of migrations returned running down: %d", len(res))
|
|
||||||
}
|
|
||||||
return res[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DownTo rolls back all migrations down to, but not including, the specified version.
|
|
||||||
//
|
|
||||||
// For instance, if the current database version is 11,10,9... and the requested version is 9, only
|
|
||||||
// migrations 11, 10 will be rolled back.
|
|
||||||
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
|
||||||
if version < 0 {
|
|
||||||
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
|
|
||||||
}
|
|
||||||
return p.down(ctx, false, version)
|
|
||||||
}
|
|
|
@ -1,153 +0,0 @@
|
||||||
package provider_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"io/fs"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
"testing/fstest"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/database"
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
|
||||||
"github.com/pressly/goose/v3/internal/provider"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestProvider(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
|
||||||
check.NoError(t, err)
|
|
||||||
t.Run("empty", func(t *testing.T) {
|
|
||||||
_, err := provider.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
|
|
||||||
})
|
|
||||||
|
|
||||||
mapFS := fstest.MapFS{
|
|
||||||
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
|
|
||||||
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
|
||||||
}
|
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
|
||||||
check.NoError(t, err)
|
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys)
|
|
||||||
check.NoError(t, err)
|
|
||||||
sources := p.ListSources()
|
|
||||||
check.Equal(t, len(sources), 2)
|
|
||||||
check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1))
|
|
||||||
check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2))
|
|
||||||
|
|
||||||
t.Run("duplicate_go", func(t *testing.T) {
|
|
||||||
// Not parallel because it modifies global state.
|
|
||||||
register := []*provider.MigrationCopy{
|
|
||||||
{
|
|
||||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
|
||||||
UpFnContext: nil,
|
|
||||||
DownFnContext: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err := provider.SetGlobalGoMigrations(register)
|
|
||||||
check.NoError(t, err)
|
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
|
||||||
|
|
||||||
db := newDB(t)
|
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
|
|
||||||
provider.WithGoMigration(1, nil, nil),
|
|
||||||
)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Equal(t, err.Error(), "go migration with version 1 already registered")
|
|
||||||
})
|
|
||||||
t.Run("empty_go", func(t *testing.T) {
|
|
||||||
db := newDB(t)
|
|
||||||
// explicit
|
|
||||||
_, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
|
||||||
provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}),
|
|
||||||
)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "go migration with version 1 must have an up function")
|
|
||||||
})
|
|
||||||
t.Run("duplicate_up", func(t *testing.T) {
|
|
||||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
|
||||||
{
|
|
||||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
|
||||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
|
||||||
UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
|
||||||
},
|
|
||||||
})
|
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
|
||||||
})
|
|
||||||
t.Run("duplicate_down", func(t *testing.T) {
|
|
||||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
|
||||||
{
|
|
||||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
|
||||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
|
||||||
DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil },
|
|
||||||
},
|
|
||||||
})
|
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
|
||||||
})
|
|
||||||
t.Run("not_registered", func(t *testing.T) {
|
|
||||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
|
||||||
{
|
|
||||||
Version: 1, Source: "00001_users_table.go",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "migration must be registered")
|
|
||||||
})
|
|
||||||
t.Run("zero_not_allowed", func(t *testing.T) {
|
|
||||||
err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{
|
|
||||||
{
|
|
||||||
Version: 0,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
|
||||||
check.HasError(t, err)
|
|
||||||
check.Contains(t, err.Error(), "migration versions must be greater than zero")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
migration1 = `
|
|
||||||
-- +goose Up
|
|
||||||
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
|
||||||
-- +goose Down
|
|
||||||
DROP TABLE foo;
|
|
||||||
`
|
|
||||||
migration2 = `
|
|
||||||
-- +goose Up
|
|
||||||
ALTER TABLE foo ADD COLUMN name TEXT;
|
|
||||||
-- +goose Down
|
|
||||||
ALTER TABLE foo DROP COLUMN name;
|
|
||||||
`
|
|
||||||
migration3 = `
|
|
||||||
-- +goose Up
|
|
||||||
CREATE TABLE bar (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
description TEXT
|
|
||||||
);
|
|
||||||
-- +goose Down
|
|
||||||
DROP TABLE bar;
|
|
||||||
`
|
|
||||||
migration4 = `
|
|
||||||
-- +goose Up
|
|
||||||
-- Rename the 'foo' table to 'my_foo'
|
|
||||||
ALTER TABLE foo RENAME TO my_foo;
|
|
||||||
|
|
||||||
-- Add a new column 'timestamp' to 'my_foo'
|
|
||||||
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
|
||||||
|
|
||||||
-- +goose Down
|
|
||||||
-- Remove the 'timestamp' column from 'my_foo'
|
|
||||||
ALTER TABLE my_foo DROP COLUMN timestamp;
|
|
||||||
|
|
||||||
-- Rename the 'my_foo' table back to 'foo'
|
|
||||||
ALTER TABLE my_foo RENAME TO foo;
|
|
||||||
`
|
|
||||||
)
|
|
|
@ -1,516 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/database"
|
|
||||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
|
||||||
"go.uber.org/multierr"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errMissingZeroVersion = errors.New("missing zero version migration")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
|
||||||
if version < 1 {
|
|
||||||
return nil, errors.New("version must be greater than zero")
|
|
||||||
}
|
|
||||||
conn, cleanup, err := p.initialize(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
retErr = multierr.Append(retErr, cleanup())
|
|
||||||
}()
|
|
||||||
if len(p.migrations) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
var apply []*migration
|
|
||||||
if p.cfg.disableVersioning {
|
|
||||||
apply = p.migrations
|
|
||||||
} else {
|
|
||||||
// optimize(mf): Listing all migrations from the database isn't great. This is only required
|
|
||||||
// to support the allow missing (out-of-order) feature. For users that don't use this
|
|
||||||
// feature, we could just query the database for the current max version and then apply
|
|
||||||
// migrations greater than that version.
|
|
||||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(dbMigrations) == 0 {
|
|
||||||
return nil, errMissingZeroVersion
|
|
||||||
}
|
|
||||||
apply, err = p.resolveUpMigrations(dbMigrations, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
|
|
||||||
// transaction. The default is to apply each migration sequentially on its own.
|
|
||||||
// https://github.com/pressly/goose/issues/222
|
|
||||||
//
|
|
||||||
// Careful, we can't use a single transaction for all migrations because some may have to be run
|
|
||||||
// in their own transaction.
|
|
||||||
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) resolveUpMigrations(
|
|
||||||
dbVersions []*database.ListMigrationsResult,
|
|
||||||
version int64,
|
|
||||||
) ([]*migration, error) {
|
|
||||||
var apply []*migration
|
|
||||||
var dbMaxVersion int64
|
|
||||||
// dbAppliedVersions is a map of all applied migrations in the database.
|
|
||||||
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
|
|
||||||
for _, m := range dbVersions {
|
|
||||||
dbAppliedVersions[m.Version] = true
|
|
||||||
if m.Version > dbMaxVersion {
|
|
||||||
dbMaxVersion = m.Version
|
|
||||||
}
|
|
||||||
}
|
|
||||||
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
|
|
||||||
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
|
|
||||||
// migrations entirely. At the moment this is not supported, but leaving this comment because
|
|
||||||
// that's where that logic would be handled.
|
|
||||||
//
|
|
||||||
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
|
|
||||||
// sure if this is a common use case, but it's possible.
|
|
||||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
|
||||||
var collected []string
|
|
||||||
for _, v := range missingMigrations {
|
|
||||||
collected = append(collected, v.filename)
|
|
||||||
}
|
|
||||||
msg := "migration"
|
|
||||||
if len(collected) > 1 {
|
|
||||||
msg += "s"
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
|
|
||||||
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for _, v := range missingMigrations {
|
|
||||||
m, err := p.getMigration(v.versionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
apply = append(apply, m)
|
|
||||||
}
|
|
||||||
// filter all migrations with a version greater than the supplied version (min) and less than or
|
|
||||||
// equal to the requested version (max). Skip any migrations that have already been applied.
|
|
||||||
for _, m := range p.migrations {
|
|
||||||
if dbAppliedVersions[m.Source.Version] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
|
|
||||||
apply = append(apply, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return apply, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
|
|
||||||
conn, cleanup, err := p.initialize(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
retErr = multierr.Append(retErr, cleanup())
|
|
||||||
}()
|
|
||||||
if len(p.migrations) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if p.cfg.disableVersioning {
|
|
||||||
downMigrations := p.migrations
|
|
||||||
if downByOne {
|
|
||||||
last := p.migrations[len(p.migrations)-1]
|
|
||||||
downMigrations = []*migration{last}
|
|
||||||
}
|
|
||||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
|
||||||
}
|
|
||||||
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(dbMigrations) == 0 {
|
|
||||||
return nil, errMissingZeroVersion
|
|
||||||
}
|
|
||||||
if dbMigrations[0].Version == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
var downMigrations []*migration
|
|
||||||
for _, dbMigration := range dbMigrations {
|
|
||||||
if dbMigration.Version <= version {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
m, err := p.getMigration(dbMigration.Version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
downMigrations = append(downMigrations, m)
|
|
||||||
}
|
|
||||||
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
|
|
||||||
}
|
|
||||||
|
|
||||||
// runMigrations runs migrations sequentially in the given direction.
|
|
||||||
//
|
|
||||||
// If the migrations list is empty, return nil without error.
|
|
||||||
func (p *Provider) runMigrations(
|
|
||||||
ctx context.Context,
|
|
||||||
conn *sql.Conn,
|
|
||||||
migrations []*migration,
|
|
||||||
direction sqlparser.Direction,
|
|
||||||
byOne bool,
|
|
||||||
) ([]*MigrationResult, error) {
|
|
||||||
if len(migrations) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
apply := migrations
|
|
||||||
if byOne {
|
|
||||||
apply = migrations[:1]
|
|
||||||
}
|
|
||||||
// Lazily parse SQL migrations (if any) in both directions. We do this before running any
|
|
||||||
// migrations so that we can fail fast if there are any errors and avoid leaving the database in
|
|
||||||
// a partially migrated state.
|
|
||||||
if err := parseSQL(p.fsys, false, apply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
|
|
||||||
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
|
|
||||||
// to run in a transaction.
|
|
||||||
|
|
||||||
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but
|
|
||||||
// are locking the database with *sql.Conn. If the caller sets max open connections to 1, then
|
|
||||||
// this will deadlock because the Go migration will try to acquire a connection from the pool,
|
|
||||||
// but the pool is locked.
|
|
||||||
//
|
|
||||||
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to use
|
|
||||||
// *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of
|
|
||||||
// an edge case.
|
|
||||||
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
|
|
||||||
for _, m := range apply {
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeGo:
|
|
||||||
if m.Go != nil && m.useTx(direction.ToBool()) {
|
|
||||||
return nil, errors.New("potential deadlock detected: cannot run Go migrations without a transaction when max open connections set to 1")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Avoid allocating a slice because we may have a partial migration error.
|
|
||||||
// 1. Avoid giving the impression that N migrations were applied when in fact some were not
|
|
||||||
// 2. Avoid the caller having to check for nil results
|
|
||||||
var results []*MigrationResult
|
|
||||||
for _, m := range apply {
|
|
||||||
current := &MigrationResult{
|
|
||||||
Source: m.Source,
|
|
||||||
Direction: direction.String(),
|
|
||||||
Empty: m.isEmpty(direction.ToBool()),
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil {
|
|
||||||
// TODO(mf): we should also return the pending migrations here, the remaining items in
|
|
||||||
// the apply slice.
|
|
||||||
current.Error = err
|
|
||||||
current.Duration = time.Since(start)
|
|
||||||
return nil, &PartialError{
|
|
||||||
Applied: results,
|
|
||||||
Failed: current,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
current.Duration = time.Since(start)
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) runIndividually(
|
|
||||||
ctx context.Context,
|
|
||||||
conn *sql.Conn,
|
|
||||||
direction bool,
|
|
||||||
m *migration,
|
|
||||||
) error {
|
|
||||||
if m.useTx(direction) {
|
|
||||||
// Run the migration in a transaction.
|
|
||||||
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
|
|
||||||
if err := m.run(ctx, tx, direction); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if p.cfg.disableVersioning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if direction {
|
|
||||||
return p.store.Insert(ctx, tx, database.InsertRequest{Version: m.Source.Version})
|
|
||||||
}
|
|
||||||
return p.store.Delete(ctx, tx, m.Source.Version)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
// Run the migration outside of a transaction.
|
|
||||||
switch m.Source.Type {
|
|
||||||
case TypeGo:
|
|
||||||
// Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the
|
|
||||||
// GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open
|
|
||||||
// connections to 1. See the comment in runMigrations for more details.
|
|
||||||
if err := m.runNoTx(ctx, p.db, direction); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case TypeSQL:
|
|
||||||
if err := m.runConn(ctx, conn, direction); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.cfg.disableVersioning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if direction {
|
|
||||||
return p.store.Insert(ctx, conn, database.InsertRequest{Version: m.Source.Version})
|
|
||||||
}
|
|
||||||
return p.store.Delete(ctx, conn, m.Source.Version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// beginTx begins a transaction and runs the given function. If the function returns an error, the
|
|
||||||
// transaction is rolled back. Otherwise, the transaction is committed.
|
|
||||||
//
|
|
||||||
// If the provider is configured to use versioning, this function also inserts or deletes the
|
|
||||||
// migration version.
|
|
||||||
func (p *Provider) beginTx(
|
|
||||||
ctx context.Context,
|
|
||||||
conn *sql.Conn,
|
|
||||||
fn func(tx *sql.Tx) error,
|
|
||||||
) (retErr error) {
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if retErr != nil {
|
|
||||||
retErr = multierr.Append(retErr, tx.Rollback())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if err := fn(tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
|
|
||||||
p.mu.Lock()
|
|
||||||
conn, err := p.db.Conn(ctx)
|
|
||||||
if err != nil {
|
|
||||||
p.mu.Unlock()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
// cleanup is a function that cleans up the connection, and optionally, the session lock.
|
|
||||||
cleanup := func() error {
|
|
||||||
p.mu.Unlock()
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
|
|
||||||
if err := l.SessionLock(ctx, conn); err != nil {
|
|
||||||
return nil, nil, multierr.Append(err, cleanup())
|
|
||||||
}
|
|
||||||
cleanup = func() error {
|
|
||||||
p.mu.Unlock()
|
|
||||||
// Use a detached context to unlock the session. This is because the context passed to
|
|
||||||
// SessionLock may have been canceled, and we don't want to cancel the unlock. TODO(mf):
|
|
||||||
// use [context.WithoutCancel] added in go1.21
|
|
||||||
detachedCtx := context.Background()
|
|
||||||
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
|
|
||||||
// need the version table because there is no versioning.
|
|
||||||
if !p.cfg.disableVersioning {
|
|
||||||
if err := p.ensureVersionTable(ctx, conn); err != nil {
|
|
||||||
return nil, nil, multierr.Append(err, cleanup())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return conn, cleanup, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
|
|
||||||
// will not be parsed again.
|
|
||||||
//
|
|
||||||
// Important: This function will mutate SQL migrations and is not safe for concurrent use.
|
|
||||||
func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error {
|
|
||||||
for _, m := range migrations {
|
|
||||||
// If the migration is a SQL migration, and it has not been parsed, parse it.
|
|
||||||
if m.Source.Type == TypeSQL && m.SQL == nil {
|
|
||||||
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Path, debug)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.SQL = &sqlMigration{
|
|
||||||
UseTx: parsed.UseTx,
|
|
||||||
UpStatements: parsed.Up,
|
|
||||||
DownStatements: parsed.Down,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
|
||||||
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
|
||||||
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
|
||||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
|
||||||
if err == nil && res != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return p.beginTx(ctx, conn, func(tx *sql.Tx) error {
|
|
||||||
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if p.cfg.disableVersioning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type missingMigration struct {
|
|
||||||
versionID int64
|
|
||||||
filename string
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
|
|
||||||
// migration is one that has a version less than the max version in the database.
|
|
||||||
func checkMissingMigrations(
|
|
||||||
dbMigrations []*database.ListMigrationsResult,
|
|
||||||
fsMigrations []*migration,
|
|
||||||
) []missingMigration {
|
|
||||||
existing := make(map[int64]bool)
|
|
||||||
var dbMaxVersion int64
|
|
||||||
for _, m := range dbMigrations {
|
|
||||||
existing[m.Version] = true
|
|
||||||
if m.Version > dbMaxVersion {
|
|
||||||
dbMaxVersion = m.Version
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var missing []missingMigration
|
|
||||||
for _, m := range fsMigrations {
|
|
||||||
version := m.Source.Version
|
|
||||||
if !existing[version] && version < dbMaxVersion {
|
|
||||||
missing = append(missing, missingMigration{
|
|
||||||
versionID: version,
|
|
||||||
filename: m.filename(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Slice(missing, func(i, j int) bool {
|
|
||||||
return missing[i].versionID < missing[j].versionID
|
|
||||||
})
|
|
||||||
return missing
|
|
||||||
}
|
|
||||||
|
|
||||||
// getMigration returns the migration with the given version. If no migration is found, then
|
|
||||||
// ErrVersionNotFound is returned.
|
|
||||||
func (p *Provider) getMigration(version int64) (*migration, error) {
|
|
||||||
for _, m := range p.migrations {
|
|
||||||
if m.Source.Version == version {
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, ErrVersionNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
|
|
||||||
m, err := p.getMigration(version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, cleanup, err := p.initialize(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
retErr = multierr.Append(retErr, cleanup())
|
|
||||||
}()
|
|
||||||
|
|
||||||
result, err := p.store.GetMigration(ctx, conn, version)
|
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// If the migration has already been applied, return an error, unless the migration is being
|
|
||||||
// applied in the opposite direction. In that case, we allow the migration to be applied again.
|
|
||||||
if result != nil && direction {
|
|
||||||
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
|
||||||
}
|
|
||||||
|
|
||||||
d := sqlparser.DirectionDown
|
|
||||||
if direction {
|
|
||||||
d = sqlparser.DirectionUp
|
|
||||||
}
|
|
||||||
results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(results) == 0 {
|
|
||||||
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
|
||||||
}
|
|
||||||
return results[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
|
|
||||||
conn, cleanup, err := p.initialize(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
retErr = multierr.Append(retErr, cleanup())
|
|
||||||
}()
|
|
||||||
|
|
||||||
// TODO(mf): add support for limit and order. Also would be nice to refactor the list query to
|
|
||||||
// support limiting the set.
|
|
||||||
|
|
||||||
status := make([]*MigrationStatus, 0, len(p.migrations))
|
|
||||||
for _, m := range p.migrations {
|
|
||||||
migrationStatus := &MigrationStatus{
|
|
||||||
Source: m.Source,
|
|
||||||
State: StatePending,
|
|
||||||
}
|
|
||||||
dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version)
|
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if dbResult != nil {
|
|
||||||
migrationStatus.State = StateApplied
|
|
||||||
migrationStatus.AppliedAt = dbResult.Timestamp
|
|
||||||
}
|
|
||||||
status = append(status, migrationStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
return status, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) {
|
|
||||||
conn, cleanup, err := p.initialize(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
retErr = multierr.Append(retErr, cleanup())
|
|
||||||
}()
|
|
||||||
|
|
||||||
res, err := p.store.ListMigrations(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(res) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
sort.Slice(res, func(i, j int) bool {
|
|
||||||
return res[i].Version > res[j].Version
|
|
||||||
})
|
|
||||||
return res[0].Version, nil
|
|
||||||
}
|
|
|
@ -13,17 +13,25 @@ import (
|
||||||
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
|
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
|
||||||
// session-level advisory lock mechanism.
|
// session-level advisory lock mechanism.
|
||||||
//
|
//
|
||||||
// This function creates a SessionLocker that can be used to acquire and release locks for
|
// This function creates a SessionLocker that can be used to acquire and release a lock for
|
||||||
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
|
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
|
||||||
// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the
|
// until the failure threshold is reached. The default lock duration is set to 5 minutes, and the
|
||||||
// default unlock duration is set to 1 minute.
|
// default unlock duration is set to 1 minute.
|
||||||
//
|
//
|
||||||
|
// If you have long running migrations, you may want to increase the lock duration.
|
||||||
|
//
|
||||||
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
|
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
|
||||||
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
|
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
|
||||||
cfg := sessionLockerConfig{
|
cfg := sessionLockerConfig{
|
||||||
lockID: DefaultLockID,
|
lockID: DefaultLockID,
|
||||||
lockTimeout: DefaultLockTimeout,
|
lockProbe: probe{
|
||||||
unlockTimeout: DefaultUnlockTimeout,
|
periodSeconds: 5 * time.Second,
|
||||||
|
failureThreshold: 60,
|
||||||
|
},
|
||||||
|
unlockProbe: probe{
|
||||||
|
periodSeconds: 2 * time.Second,
|
||||||
|
failureThreshold: 30,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt.apply(&cfg); err != nil {
|
if err := opt.apply(&cfg); err != nil {
|
||||||
|
@ -32,13 +40,13 @@ func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error
|
||||||
}
|
}
|
||||||
return &postgresSessionLocker{
|
return &postgresSessionLocker{
|
||||||
lockID: cfg.lockID,
|
lockID: cfg.lockID,
|
||||||
retryLock: retry.WithMaxDuration(
|
retryLock: retry.WithMaxRetries(
|
||||||
cfg.lockTimeout,
|
cfg.lockProbe.failureThreshold,
|
||||||
retry.NewConstant(2*time.Second),
|
retry.NewConstant(cfg.lockProbe.periodSeconds),
|
||||||
),
|
),
|
||||||
retryUnlock: retry.WithMaxDuration(
|
retryUnlock: retry.WithMaxRetries(
|
||||||
cfg.unlockTimeout,
|
cfg.unlockProbe.failureThreshold,
|
||||||
retry.NewConstant(2*time.Second),
|
retry.NewConstant(cfg.unlockProbe.periodSeconds),
|
||||||
),
|
),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/pressly/goose/v3/internal/check"
|
||||||
"github.com/pressly/goose/v3/internal/testdb"
|
"github.com/pressly/goose/v3/internal/testdb"
|
||||||
|
@ -30,8 +29,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
||||||
)
|
)
|
||||||
locker, err := lock.NewPostgresSessionLocker(
|
locker, err := lock.NewPostgresSessionLocker(
|
||||||
lock.WithLockID(lockID),
|
lock.WithLockID(lockID),
|
||||||
lock.WithLockTimeout(4*time.Second),
|
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||||
lock.WithUnlockTimeout(4*time.Second),
|
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -60,8 +59,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("lock_close_conn_unlock", func(t *testing.T) {
|
t.Run("lock_close_conn_unlock", func(t *testing.T) {
|
||||||
locker, err := lock.NewPostgresSessionLocker(
|
locker, err := lock.NewPostgresSessionLocker(
|
||||||
lock.WithLockTimeout(4*time.Second),
|
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||||
lock.WithUnlockTimeout(4*time.Second),
|
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -103,10 +102,12 @@ func TestPostgresSessionLocker(t *testing.T) {
|
||||||
// Exactly one connection should acquire the lock. While the other connections
|
// Exactly one connection should acquire the lock. While the other connections
|
||||||
// should fail to acquire the lock and timeout.
|
// should fail to acquire the lock and timeout.
|
||||||
locker, err := lock.NewPostgresSessionLocker(
|
locker, err := lock.NewPostgresSessionLocker(
|
||||||
lock.WithLockTimeout(4*time.Second),
|
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||||
lock.WithUnlockTimeout(4*time.Second),
|
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
// NOTE, we are not unlocking the lock, because we want to test that the lock is
|
||||||
|
// released when the connection is closed.
|
||||||
ch <- locker.SessionLock(ctx, conn)
|
ch <- locker.SessionLock(ctx, conn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -138,8 +139,8 @@ func TestPostgresSessionLocker(t *testing.T) {
|
||||||
)
|
)
|
||||||
locker, err := lock.NewPostgresSessionLocker(
|
locker, err := lock.NewPostgresSessionLocker(
|
||||||
lock.WithLockID(lockID),
|
lock.WithLockID(lockID),
|
||||||
lock.WithLockTimeout(4*time.Second),
|
lock.WithLockTimeout(1, 4), // 4 second timeout
|
||||||
lock.WithUnlockTimeout(4*time.Second),
|
lock.WithUnlockTimeout(1, 4), // 4 second timeout
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
|
||||||
|
@ -179,6 +180,7 @@ func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer rows.Close()
|
||||||
var pgLocks []pgLock
|
var pgLocks []pgLock
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var p pgLock
|
var p pgLock
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package lock
|
package lock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,11 +11,6 @@ const (
|
||||||
//
|
//
|
||||||
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
|
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
|
||||||
DefaultLockID int64 = 5887940537704921958
|
DefaultLockID int64 = 5887940537704921958
|
||||||
|
|
||||||
// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
|
|
||||||
// wait for the lock to be released) wait durations.
|
|
||||||
DefaultLockTimeout time.Duration = 60 * time.Minute
|
|
||||||
DefaultUnlockTimeout time.Duration = 1 * time.Minute
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionLockerOption is used to configure a SessionLocker.
|
// SessionLockerOption is used to configure a SessionLocker.
|
||||||
|
@ -32,26 +28,65 @@ func WithLockID(lockID int64) SessionLockerOption {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithLockTimeout sets the max duration to wait for the lock to be acquired.
|
// WithLockTimeout sets the max duration to wait for the lock to be acquired. The total duration
|
||||||
func WithLockTimeout(duration time.Duration) SessionLockerOption {
|
// will be the period times the failure threshold.
|
||||||
|
//
|
||||||
|
// By default, the lock timeout is 300s (5min), where the lock is retried every 5 seconds (period)
|
||||||
|
// up to 60 times (failure threshold).
|
||||||
|
//
|
||||||
|
// The minimum period is 1 second, and the minimum failure threshold is 1.
|
||||||
|
func WithLockTimeout(period, failureThreshold uint64) SessionLockerOption {
|
||||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||||
c.lockTimeout = duration
|
if period < 1 {
|
||||||
|
return errors.New("period must be greater than 0, minimum is 1")
|
||||||
|
}
|
||||||
|
if failureThreshold < 1 {
|
||||||
|
return errors.New("failure threshold must be greater than 0, minimum is 1")
|
||||||
|
}
|
||||||
|
c.lockProbe = probe{
|
||||||
|
periodSeconds: time.Duration(period) * time.Second,
|
||||||
|
failureThreshold: failureThreshold,
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithUnlockTimeout sets the max duration to wait for the lock to be released.
|
// WithUnlockTimeout sets the max duration to wait for the lock to be released. The total duration
|
||||||
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
|
// will be the period times the failure threshold.
|
||||||
|
//
|
||||||
|
// By default, the lock timeout is 60s, where the lock is retried every 2 seconds (period) up to 30
|
||||||
|
// times (failure threshold).
|
||||||
|
//
|
||||||
|
// The minimum period is 1 second, and the minimum failure threshold is 1.
|
||||||
|
func WithUnlockTimeout(period, failureThreshold uint64) SessionLockerOption {
|
||||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||||
c.unlockTimeout = duration
|
if period < 1 {
|
||||||
|
return errors.New("period must be greater than 0, minimum is 1")
|
||||||
|
}
|
||||||
|
if failureThreshold < 1 {
|
||||||
|
return errors.New("failure threshold must be greater than 0, minimum is 1")
|
||||||
|
}
|
||||||
|
c.unlockProbe = probe{
|
||||||
|
periodSeconds: time.Duration(period) * time.Second,
|
||||||
|
failureThreshold: failureThreshold,
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type sessionLockerConfig struct {
|
type sessionLockerConfig struct {
|
||||||
lockID int64
|
lockID int64
|
||||||
lockTimeout time.Duration
|
lockProbe probe
|
||||||
unlockTimeout time.Duration
|
unlockProbe probe
|
||||||
|
}
|
||||||
|
|
||||||
|
// probe is used to configure how often and how many times to retry a lock or unlock operation. The
|
||||||
|
// total timeout will be the period times the failure threshold.
|
||||||
|
type probe struct {
|
||||||
|
// How often (in seconds) to perform the probe.
|
||||||
|
periodSeconds time.Duration
|
||||||
|
// Number of times to retry the probe.
|
||||||
|
failureThreshold uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)
|
var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)
|
||||||
|
|
54
migration.go
54
migration.go
|
@ -18,22 +18,36 @@ import (
|
||||||
// Both up and down functions may be nil, in which case the migration will be recorded in the
|
// Both up and down functions may be nil, in which case the migration will be recorded in the
|
||||||
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
|
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
|
||||||
// a version without running any functions. See [GoFunc] for more details.
|
// a version without running any functions. See [GoFunc] for more details.
|
||||||
func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
func NewGoMigration(version int64, up, down *GoFunc) *Migration {
|
||||||
m := Migration{
|
m := &Migration{
|
||||||
Type: TypeGo,
|
Type: TypeGo,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
Version: version,
|
Version: version,
|
||||||
Next: -1, Previous: -1,
|
Next: -1, Previous: -1,
|
||||||
goUp: up,
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
goDown: down,
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
construct: true,
|
construct: true,
|
||||||
}
|
}
|
||||||
|
updateMode := func(f *GoFunc) *GoFunc {
|
||||||
|
// infer mode from function
|
||||||
|
if f.Mode == 0 {
|
||||||
|
if f.RunTx != nil && f.RunDB == nil {
|
||||||
|
f.Mode = TransactionEnabled
|
||||||
|
}
|
||||||
|
if f.RunTx == nil && f.RunDB != nil {
|
||||||
|
f.Mode = TransactionDisabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f
|
||||||
|
}
|
||||||
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
|
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
|
||||||
// we will remove these fields in favor of [GoFunc].
|
// we will remove these fields in favor of [GoFunc].
|
||||||
//
|
//
|
||||||
// Note, this function does not do any validation. Validation is lazily done when the migration
|
// Note, this function does not do any validation. Validation is lazily done when the migration
|
||||||
// is registered.
|
// is registered.
|
||||||
if up != nil {
|
if up != nil {
|
||||||
|
m.goUp = updateMode(up)
|
||||||
|
|
||||||
if up.RunDB != nil {
|
if up.RunDB != nil {
|
||||||
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
|
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
|
||||||
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
|
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
|
||||||
|
@ -45,6 +59,8 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if down != nil {
|
if down != nil {
|
||||||
|
m.goDown = updateMode(down)
|
||||||
|
|
||||||
if down.RunDB != nil {
|
if down.RunDB != nil {
|
||||||
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
|
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
|
||||||
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
|
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
|
||||||
|
@ -55,12 +71,6 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
||||||
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
|
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.goUp == nil {
|
|
||||||
m.goUp = &GoFunc{Mode: TransactionEnabled}
|
|
||||||
}
|
|
||||||
if m.goDown == nil {
|
|
||||||
m.goDown = &GoFunc{Mode: TransactionEnabled}
|
|
||||||
}
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,10 +86,6 @@ type Migration struct {
|
||||||
|
|
||||||
UpFnContext, DownFnContext GoMigrationContext
|
UpFnContext, DownFnContext GoMigrationContext
|
||||||
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
||||||
// These fields are used internally by goose and users are not expected to set them. Instead,
|
|
||||||
// use [NewGoMigration] to create a new go migration.
|
|
||||||
construct bool
|
|
||||||
goUp, goDown *GoFunc
|
|
||||||
|
|
||||||
// These fields will be removed in a future major version. They are here for backwards
|
// These fields will be removed in a future major version. They are here for backwards
|
||||||
// compatibility and are an implementation detail.
|
// compatibility and are an implementation detail.
|
||||||
|
@ -98,6 +104,26 @@ type Migration struct {
|
||||||
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
|
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
|
||||||
|
|
||||||
noVersioning bool
|
noVersioning bool
|
||||||
|
|
||||||
|
// These fields are used internally by goose and users are not expected to set them. Instead,
|
||||||
|
// use [NewGoMigration] to create a new go migration.
|
||||||
|
construct bool
|
||||||
|
goUp, goDown *GoFunc
|
||||||
|
|
||||||
|
sql sqlMigration
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlMigration struct {
|
||||||
|
// The Parsed field is used to track whether the SQL migration has been parsed. It serves as an
|
||||||
|
// optimization to avoid parsing migrations that may never be needed. Typically, migrations are
|
||||||
|
// incremental, and users often run only the most recent ones, making parsing of prior
|
||||||
|
// migrations unnecessary in most cases.
|
||||||
|
Parsed bool
|
||||||
|
|
||||||
|
// Parsed must be set to true before the following fields are used.
|
||||||
|
UseTx bool
|
||||||
|
Up []string
|
||||||
|
Down []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoFunc represents a Go migration function.
|
// GoFunc represents a Go migration function.
|
||||||
|
|
8
osfs.go
8
osfs.go
|
@ -18,3 +18,11 @@ func (osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(filepath.Fro
|
||||||
func (osFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(filepath.FromSlash(name)) }
|
func (osFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(filepath.FromSlash(name)) }
|
||||||
|
|
||||||
func (osFS) Glob(pattern string) ([]string, error) { return filepath.Glob(filepath.FromSlash(pattern)) }
|
func (osFS) Glob(pattern string) ([]string, error) { return filepath.Glob(filepath.FromSlash(pattern)) }
|
||||||
|
|
||||||
|
type noopFS struct{}
|
||||||
|
|
||||||
|
var _ fs.FS = noopFS{}
|
||||||
|
|
||||||
|
func (f noopFS) Open(name string) (fs.File, error) {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,477 @@
|
||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3/database"
|
||||||
|
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||||
|
"go.uber.org/multierr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider is a goose migration provider.
|
||||||
|
type Provider struct {
|
||||||
|
// mu protects all accesses to the provider and must be held when calling operations on the
|
||||||
|
// database.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
db *sql.DB
|
||||||
|
store database.Store
|
||||||
|
|
||||||
|
fsys fs.FS
|
||||||
|
cfg config
|
||||||
|
|
||||||
|
// migrations are ordered by version in ascending order.
|
||||||
|
migrations []*Migration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider returns a new goose provider.
|
||||||
|
//
|
||||||
|
// The caller is responsible for matching the database dialect with the database/sql driver. For
|
||||||
|
// example, if the database dialect is "postgres", the database/sql driver could be
|
||||||
|
// github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect]
|
||||||
|
// constant backed by a default [database.Store] implementation. For more advanced use cases, such
|
||||||
|
// as using a custom table name or supplying a custom store implementation, see [WithStore].
|
||||||
|
//
|
||||||
|
// fsys is the filesystem used to read migration files, but may be nil. Most users will want to use
|
||||||
|
// [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem.
|
||||||
|
// However, it is possible to use a different "filesystem", such as [embed.FS] or filter out
|
||||||
|
// migrations using [fs.Sub].
|
||||||
|
//
|
||||||
|
// See [ProviderOption] for more information on configuring the provider.
|
||||||
|
//
|
||||||
|
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
|
||||||
|
//
|
||||||
|
// Experimental: This API is experimental and may change in the future.
|
||||||
|
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, errors.New("db must not be nil")
|
||||||
|
}
|
||||||
|
if fsys == nil {
|
||||||
|
fsys = noopFS{}
|
||||||
|
}
|
||||||
|
cfg := config{
|
||||||
|
registered: make(map[int64]*Migration),
|
||||||
|
excludePaths: make(map[string]bool),
|
||||||
|
excludeVersions: make(map[int64]bool),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
if err := opt.apply(&cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Allow users to specify a custom store implementation, but only if they don't specify a
|
||||||
|
// dialect. If they specify a dialect, we'll use the default store implementation.
|
||||||
|
if dialect == "" && cfg.store == nil {
|
||||||
|
return nil, errors.New("dialect must not be empty")
|
||||||
|
}
|
||||||
|
if dialect != "" && cfg.store != nil {
|
||||||
|
return nil, errors.New("dialect must be empty when using a custom store implementation")
|
||||||
|
}
|
||||||
|
var store database.Store
|
||||||
|
if dialect != "" {
|
||||||
|
var err error
|
||||||
|
store, err = database.NewStore(dialect, DefaultTablename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
store = cfg.store
|
||||||
|
}
|
||||||
|
if store.Tablename() == "" {
|
||||||
|
return nil, errors.New("invalid store implementation: table name must not be empty")
|
||||||
|
}
|
||||||
|
return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProvider(
|
||||||
|
db *sql.DB,
|
||||||
|
store database.Store,
|
||||||
|
fsys fs.FS,
|
||||||
|
cfg config,
|
||||||
|
global map[int64]*Migration,
|
||||||
|
) (*Provider, error) {
|
||||||
|
// Collect migrations from the filesystem and merge with registered migrations.
|
||||||
|
//
|
||||||
|
// Note, we don't parse SQL migrations here. They are parsed lazily when required.
|
||||||
|
|
||||||
|
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
|
||||||
|
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
|
||||||
|
// we should make it optional.
|
||||||
|
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
versionToGoMigration := make(map[int64]*Migration)
|
||||||
|
// Add user-registered Go migrations from the provider.
|
||||||
|
for version, m := range cfg.registered {
|
||||||
|
versionToGoMigration[version] = m
|
||||||
|
}
|
||||||
|
// Add globally registered Go migrations.
|
||||||
|
for version, m := range global {
|
||||||
|
if _, ok := versionToGoMigration[version]; ok {
|
||||||
|
return nil, fmt.Errorf("global go migration with version %d previously registered with provider", version)
|
||||||
|
}
|
||||||
|
versionToGoMigration[version] = m
|
||||||
|
}
|
||||||
|
// At this point we have all registered unique Go migrations (if any). We need to merge them
|
||||||
|
// with SQL migrations from the filesystem.
|
||||||
|
migrations, err := merge(filesystemSources, versionToGoMigration)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(migrations) == 0 {
|
||||||
|
return nil, ErrNoMigrations
|
||||||
|
}
|
||||||
|
return &Provider{
|
||||||
|
db: db,
|
||||||
|
fsys: fsys,
|
||||||
|
cfg: cfg,
|
||||||
|
store: store,
|
||||||
|
migrations: migrations,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the status of all migrations, merging the list of migrations from the database and
|
||||||
|
// filesystem. The returned items are ordered by version, in ascending order.
|
||||||
|
func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
|
||||||
|
return p.status(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDBVersion returns the highest version recorded in the database, regardless of the order in
|
||||||
|
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
|
||||||
|
// this method returns 4. If no migrations have been applied, it returns 0.
|
||||||
|
func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) {
|
||||||
|
return p.getDBMaxVersion(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSources returns a list of all migration sources known to the provider, sorted in ascending
|
||||||
|
// order by version. The path field may be empty for manually registered migrations, such as Go
|
||||||
|
// migrations registered using the [WithGoMigrations] option.
|
||||||
|
func (p *Provider) ListSources() []*Source {
|
||||||
|
sources := make([]*Source, 0, len(p.migrations))
|
||||||
|
for _, m := range p.migrations {
|
||||||
|
sources = append(sources, &Source{
|
||||||
|
Type: m.Type,
|
||||||
|
Path: m.Source,
|
||||||
|
Version: m.Version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return sources
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping attempts to ping the database to verify a connection is available.
|
||||||
|
func (p *Provider) Ping(ctx context.Context) error {
|
||||||
|
return p.db.PingContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the database connection initially supplied to the provider.
|
||||||
|
func (p *Provider) Close() error {
|
||||||
|
return p.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyVersion applies exactly one migration for the specified version. If there is no migration
|
||||||
|
// available for the specified version, this method returns [ErrVersionNotFound]. If the migration
|
||||||
|
// has already been applied, this method returns [ErrAlreadyApplied].
|
||||||
|
//
|
||||||
|
// The direction parameter determines the migration direction: true for up migration and false for
|
||||||
|
// down migration.
|
||||||
|
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
|
||||||
|
res, err := p.apply(ctx, version, direction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// This should never happen, we must return exactly one result.
|
||||||
|
if len(res) != 1 {
|
||||||
|
versions := make([]string, 0, len(res))
|
||||||
|
for _, r := range res {
|
||||||
|
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"unexpected number of migrations applied running apply, expecting exactly one result: %v",
|
||||||
|
strings.Join(versions, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return res[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
|
||||||
|
// empty list and nil error.
|
||||||
|
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
|
||||||
|
return p.up(ctx, false, math.MaxInt64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpByOne applies the next pending migration. If there is no next migration to apply, this method
|
||||||
|
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
|
||||||
|
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
|
||||||
|
res, err := p.up(ctx, true, math.MaxInt64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res) == 0 {
|
||||||
|
return nil, ErrNoNextVersion
|
||||||
|
}
|
||||||
|
// This should never happen, we must return exactly one result.
|
||||||
|
if len(res) != 1 {
|
||||||
|
versions := make([]string, 0, len(res))
|
||||||
|
for _, r := range res {
|
||||||
|
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"unexpected number of migrations applied running up-by-one, expecting exactly one result: %v",
|
||||||
|
strings.Join(versions, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return res[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpTo applies all pending migrations up to, and including, the specified version. If there are no
|
||||||
|
// migrations to apply, this method returns empty list and nil error.
|
||||||
|
//
|
||||||
|
// For example, if there are three new migrations (9,10,11) and the current database version is 8
|
||||||
|
// with a requested version of 10, only versions 9,10 will be applied.
|
||||||
|
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||||
|
return p.up(ctx, false, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Down rolls back the most recently applied migration. If there are no migrations to rollback, this
|
||||||
|
// method returns [ErrNoNextVersion].
|
||||||
|
//
|
||||||
|
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
|
||||||
|
// the migration version. This only applies in scenarios where migrations are allowed to be applied
|
||||||
|
// out of order.
|
||||||
|
func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
|
||||||
|
res, err := p.down(ctx, true, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res) == 0 {
|
||||||
|
return nil, ErrNoNextVersion
|
||||||
|
}
|
||||||
|
// This should never happen, we must return exactly one result.
|
||||||
|
if len(res) != 1 {
|
||||||
|
versions := make([]string, 0, len(res))
|
||||||
|
for _, r := range res {
|
||||||
|
versions = append(versions, strconv.FormatInt(r.Source.Version, 10))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"unexpected number of migrations applied running down, expecting exactly one result: %v",
|
||||||
|
strings.Join(versions, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return res[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownTo rolls back all migrations down to, but not including, the specified version.
|
||||||
|
//
|
||||||
|
// For example, if the current database version is 11,10,9... and the requested version is 9, only
|
||||||
|
// migrations 11, 10 will be rolled back.
|
||||||
|
//
|
||||||
|
// Note, migrations are rolled back in the order they were applied. And not in the reverse order of
|
||||||
|
// the migration version. This only applies in scenarios where migrations are allowed to be applied
|
||||||
|
// out of order.
|
||||||
|
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
|
||||||
|
if version < 0 {
|
||||||
|
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
|
||||||
|
}
|
||||||
|
return p.down(ctx, false, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// *** Internal methods ***
|
||||||
|
|
||||||
|
func (p *Provider) up(
|
||||||
|
ctx context.Context,
|
||||||
|
byOne bool,
|
||||||
|
version int64,
|
||||||
|
) (_ []*MigrationResult, retErr error) {
|
||||||
|
if version < 1 {
|
||||||
|
return nil, errInvalidVersion
|
||||||
|
}
|
||||||
|
conn, cleanup, err := p.initialize(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
retErr = multierr.Append(retErr, cleanup())
|
||||||
|
}()
|
||||||
|
|
||||||
|
if len(p.migrations) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var apply []*Migration
|
||||||
|
if p.cfg.disableVersioning {
|
||||||
|
if byOne {
|
||||||
|
return nil, errors.New("up-by-one not supported when versioning is disabled")
|
||||||
|
}
|
||||||
|
apply = p.migrations
|
||||||
|
} else {
|
||||||
|
// optimize(mf): Listing all migrations from the database isn't great. This is only required
|
||||||
|
// to support the allow missing (out-of-order) feature. For users that don't use this
|
||||||
|
// feature, we could just query the database for the current max version and then apply
|
||||||
|
// migrations greater than that version.
|
||||||
|
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(dbMigrations) == 0 {
|
||||||
|
return nil, errMissingZeroVersion
|
||||||
|
}
|
||||||
|
apply, err = p.resolveUpMigrations(dbMigrations, version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) down(
|
||||||
|
ctx context.Context,
|
||||||
|
byOne bool,
|
||||||
|
version int64,
|
||||||
|
) (_ []*MigrationResult, retErr error) {
|
||||||
|
conn, cleanup, err := p.initialize(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
retErr = multierr.Append(retErr, cleanup())
|
||||||
|
}()
|
||||||
|
|
||||||
|
if len(p.migrations) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if p.cfg.disableVersioning {
|
||||||
|
var downMigrations []*Migration
|
||||||
|
if byOne {
|
||||||
|
last := p.migrations[len(p.migrations)-1]
|
||||||
|
downMigrations = []*Migration{last}
|
||||||
|
} else {
|
||||||
|
downMigrations = p.migrations
|
||||||
|
}
|
||||||
|
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne)
|
||||||
|
}
|
||||||
|
dbMigrations, err := p.store.ListMigrations(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(dbMigrations) == 0 {
|
||||||
|
return nil, errMissingZeroVersion
|
||||||
|
}
|
||||||
|
// We never migrate the zero version down.
|
||||||
|
if dbMigrations[0].Version == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var apply []*Migration
|
||||||
|
for _, dbMigration := range dbMigrations {
|
||||||
|
if dbMigration.Version <= version {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
m, err := p.getMigration(dbMigration.Version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
apply = append(apply, m)
|
||||||
|
}
|
||||||
|
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) apply(
|
||||||
|
ctx context.Context,
|
||||||
|
version int64,
|
||||||
|
direction bool,
|
||||||
|
) (_ []*MigrationResult, retErr error) {
|
||||||
|
if version < 1 {
|
||||||
|
return nil, errInvalidVersion
|
||||||
|
}
|
||||||
|
m, err := p.getMigration(version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, cleanup, err := p.initialize(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
retErr = multierr.Append(retErr, cleanup())
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := p.store.GetMigration(ctx, conn, version)
|
||||||
|
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// If the migration has already been applied, return an error. But, if the migration is being
|
||||||
|
// rolled back, we allow the individual migration to be applied again.
|
||||||
|
if result != nil && direction {
|
||||||
|
return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied)
|
||||||
|
}
|
||||||
|
d := sqlparser.DirectionDown
|
||||||
|
if direction {
|
||||||
|
d = sqlparser.DirectionUp
|
||||||
|
}
|
||||||
|
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
|
||||||
|
conn, cleanup, err := p.initialize(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
retErr = multierr.Append(retErr, cleanup())
|
||||||
|
}()
|
||||||
|
|
||||||
|
status := make([]*MigrationStatus, 0, len(p.migrations))
|
||||||
|
for _, m := range p.migrations {
|
||||||
|
migrationStatus := &MigrationStatus{
|
||||||
|
Source: &Source{
|
||||||
|
Type: m.Type,
|
||||||
|
Path: m.Source,
|
||||||
|
Version: m.Version,
|
||||||
|
},
|
||||||
|
State: StatePending,
|
||||||
|
}
|
||||||
|
dbResult, err := p.store.GetMigration(ctx, conn, m.Version)
|
||||||
|
if err != nil && !errors.Is(err, database.ErrVersionNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if dbResult != nil {
|
||||||
|
migrationStatus.State = StateApplied
|
||||||
|
migrationStatus.AppliedAt = dbResult.Timestamp
|
||||||
|
}
|
||||||
|
status = append(status, migrationStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
|
||||||
|
conn, cleanup, err := p.initialize(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
retErr = multierr.Append(retErr, cleanup())
|
||||||
|
}()
|
||||||
|
|
||||||
|
res, err := p.store.ListMigrations(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(res) == 0 {
|
||||||
|
return 0, errMissingZeroVersion
|
||||||
|
}
|
||||||
|
// Sort in descending order.
|
||||||
|
sort.Slice(res, func(i, j int) bool {
|
||||||
|
return res[i].Version > res[j].Version
|
||||||
|
})
|
||||||
|
// Return the highest version.
|
||||||
|
return res[0].Version, nil
|
||||||
|
}
|
|
@ -1,15 +1,12 @@
|
||||||
package provider
|
package goose
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// fileSources represents a collection of migration files on the filesystem.
|
// fileSources represents a collection of migration files on the filesystem.
|
||||||
|
@ -18,25 +15,6 @@ type fileSources struct {
|
||||||
goSources []Source
|
goSources []Source
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(mf): remove?
|
|
||||||
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
|
|
||||||
switch t {
|
|
||||||
case TypeGo:
|
|
||||||
for _, source := range s.goSources {
|
|
||||||
if source.Version == version {
|
|
||||||
return &source
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case TypeSQL:
|
|
||||||
for _, source := range s.sqlSources {
|
|
||||||
if source.Version == version {
|
|
||||||
return &source
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
|
// collectFilesystemSources scans the file system for migration files that have a numeric prefix
|
||||||
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
|
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
|
||||||
// be nil, in which case an empty fileSources is returned.
|
// be nil, in which case an empty fileSources is returned.
|
||||||
|
@ -46,7 +24,12 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source {
|
||||||
//
|
//
|
||||||
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
|
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
|
||||||
// migration sources from the filesystem.
|
// migration sources from the filesystem.
|
||||||
func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
|
func collectFilesystemSources(
|
||||||
|
fsys fs.FS,
|
||||||
|
strict bool,
|
||||||
|
excludePaths map[string]bool,
|
||||||
|
excludeVersions map[int64]bool,
|
||||||
|
) (*fileSources, error) {
|
||||||
if fsys == nil {
|
if fsys == nil {
|
||||||
return new(fileSources), nil
|
return new(fileSources), nil
|
||||||
}
|
}
|
||||||
|
@ -62,8 +45,11 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
||||||
}
|
}
|
||||||
for _, fullpath := range files {
|
for _, fullpath := range files {
|
||||||
base := filepath.Base(fullpath)
|
base := filepath.Base(fullpath)
|
||||||
// Skip explicit excludes or Go test files.
|
if strings.HasSuffix(base, "_test.go") {
|
||||||
if excludes[base] || strings.HasSuffix(base, "_test.go") {
|
continue
|
||||||
|
}
|
||||||
|
if excludePaths[base] {
|
||||||
|
// TODO(mf): log this?
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
|
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
|
||||||
|
@ -71,13 +57,17 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
||||||
// filenames, but still have versioned migrations within the same directory. For
|
// filenames, but still have versioned migrations within the same directory. For
|
||||||
// example, a user could have a helpers.go file which contains unexported helper
|
// example, a user could have a helpers.go file which contains unexported helper
|
||||||
// functions for migrations.
|
// functions for migrations.
|
||||||
version, err := goose.NumericComponent(base)
|
version, err := NumericComponent(base)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strict {
|
if strict {
|
||||||
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
|
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if excludeVersions[version] {
|
||||||
|
// TODO: log this?
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Ensure there are no duplicate versions.
|
// Ensure there are no duplicate versions.
|
||||||
if existing, ok := versionToBaseLookup[version]; ok {
|
if existing, ok := versionToBaseLookup[version]; ok {
|
||||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||||
|
@ -101,7 +91,7 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
// Should never happen since we already filtered out all other file types.
|
// Should never happen since we already filtered out all other file types.
|
||||||
return nil, fmt.Errorf("unknown migration type: %s", base)
|
return nil, fmt.Errorf("invalid file extension: %q", base)
|
||||||
}
|
}
|
||||||
// Add the version to the lookup map.
|
// Add the version to the lookup map.
|
||||||
versionToBaseLookup[version] = base
|
versionToBaseLookup[version] = base
|
||||||
|
@ -110,15 +100,25 @@ func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool)
|
||||||
return sources, nil
|
return sources, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
|
func newSQLMigration(source Source) *Migration {
|
||||||
var migrations []*migration
|
return &Migration{
|
||||||
migrationLookup := make(map[int64]*migration)
|
Type: source.Type,
|
||||||
|
Version: source.Version,
|
||||||
|
Source: source.Path,
|
||||||
|
construct: true,
|
||||||
|
Next: -1, Previous: -1,
|
||||||
|
sql: sqlMigration{
|
||||||
|
Parsed: false, // SQL migrations are parsed lazily.
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func merge(sources *fileSources, registerd map[int64]*Migration) ([]*Migration, error) {
|
||||||
|
var migrations []*Migration
|
||||||
|
migrationLookup := make(map[int64]*Migration)
|
||||||
// Add all SQL migrations to the list of migrations.
|
// Add all SQL migrations to the list of migrations.
|
||||||
for _, source := range sources.sqlSources {
|
for _, source := range sources.sqlSources {
|
||||||
m := &migration{
|
m := newSQLMigration(source)
|
||||||
Source: source,
|
|
||||||
SQL: nil, // SQL migrations are parsed lazily.
|
|
||||||
}
|
|
||||||
migrations = append(migrations, m)
|
migrations = append(migrations, m)
|
||||||
migrationLookup[source.Version] = m
|
migrationLookup[source.Version] = m
|
||||||
}
|
}
|
||||||
|
@ -147,38 +147,24 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
|
||||||
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
|
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
|
||||||
// the SQL migration files.
|
// the SQL migration files.
|
||||||
for version, r := range registerd {
|
for version, r := range registerd {
|
||||||
fullpath := r.fullpath
|
|
||||||
if fullpath == "" {
|
|
||||||
if s := sources.lookup(TypeGo, version); s != nil {
|
|
||||||
fullpath = s.Path
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Ensure there are no duplicate versions.
|
// Ensure there are no duplicate versions.
|
||||||
if existing, ok := migrationLookup[version]; ok {
|
if existing, ok := migrationLookup[version]; ok {
|
||||||
fullpath := r.fullpath
|
fullpath := r.Source
|
||||||
if fullpath == "" {
|
if fullpath == "" {
|
||||||
fullpath = "manually registered (no source)"
|
fullpath = "manually registered (no source)"
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
|
||||||
version,
|
version,
|
||||||
existing.Source.Path,
|
existing.Source,
|
||||||
fullpath,
|
fullpath,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
m := &migration{
|
migrations = append(migrations, r)
|
||||||
Source: Source{
|
migrationLookup[version] = r
|
||||||
Type: TypeGo,
|
|
||||||
Path: fullpath, // May be empty if migration was registered manually.
|
|
||||||
Version: version,
|
|
||||||
},
|
|
||||||
Go: r,
|
|
||||||
}
|
|
||||||
migrations = append(migrations, m)
|
|
||||||
migrationLookup[version] = m
|
|
||||||
}
|
}
|
||||||
// Sort migrations by version in ascending order.
|
// Sort migrations by version in ascending order.
|
||||||
sort.Slice(migrations, func(i, j int) bool {
|
sort.Slice(migrations, func(i, j int) bool {
|
||||||
return migrations[i].Source.Version < migrations[j].Source.Version
|
return migrations[i].Version < migrations[j].Version
|
||||||
})
|
})
|
||||||
return migrations, nil
|
return migrations, nil
|
||||||
}
|
}
|
||||||
|
@ -203,11 +189,3 @@ func unregisteredError(unregistered []string) error {
|
||||||
|
|
||||||
return errors.New(b.String())
|
return errors.New(b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
type noopFS struct{}
|
|
||||||
|
|
||||||
var _ fs.FS = noopFS{}
|
|
||||||
|
|
||||||
func (f noopFS) Open(name string) (fs.File, error) {
|
|
||||||
return nil, os.ErrNotExist
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package provider
|
package goose
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
@ -12,21 +12,21 @@ import (
|
||||||
func TestCollectFileSources(t *testing.T) {
|
func TestCollectFileSources(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("nil_fsys", func(t *testing.T) {
|
t.Run("nil_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(nil, false, nil)
|
sources, err := collectFilesystemSources(nil, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Bool(t, sources != nil, true)
|
check.Bool(t, sources != nil, true)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
check.Number(t, len(sources.sqlSources), 0)
|
||||||
})
|
})
|
||||||
t.Run("noop_fsys", func(t *testing.T) {
|
t.Run("noop_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(noopFS{}, false, nil)
|
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Bool(t, sources != nil, true)
|
check.Bool(t, sources != nil, true)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
check.Number(t, len(sources.sqlSources), 0)
|
||||||
})
|
})
|
||||||
t.Run("empty_fsys", func(t *testing.T) {
|
t.Run("empty_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil)
|
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
check.Number(t, len(sources.sqlSources), 0)
|
||||||
|
@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"00000_foo.sql": sqlMapFile,
|
"00000_foo.sql": sqlMapFile,
|
||||||
}
|
}
|
||||||
// strict disable - should not error
|
// strict disable - should not error
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
check.Number(t, len(sources.sqlSources), 0)
|
||||||
// strict enabled - should error
|
// strict enabled - should error
|
||||||
_, err = collectFilesystemSources(mapFS, true, nil)
|
_, err = collectFilesystemSources(mapFS, true, nil, nil)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "migration version must be greater than zero")
|
check.Contains(t, err.Error(), "migration version must be greater than zero")
|
||||||
})
|
})
|
||||||
t.Run("collect", func(t *testing.T) {
|
t.Run("collect", func(t *testing.T) {
|
||||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 4)
|
check.Number(t, len(sources.sqlSources), 4)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
|
@ -76,6 +76,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"00002_bar.sql": true,
|
"00002_bar.sql": true,
|
||||||
"00110_qux.sql": true,
|
"00110_qux.sql": true,
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 2)
|
check.Number(t, len(sources.sqlSources), 2)
|
||||||
|
@ -96,7 +97,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
_, err = collectFilesystemSources(fsys, true, nil)
|
_, err = collectFilesystemSources(fsys, true, nil, nil)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
||||||
})
|
})
|
||||||
|
@ -108,7 +109,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"4_qux.sql": sqlMapFile,
|
"4_qux.sql": sqlMapFile,
|
||||||
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
||||||
}
|
}
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 4)
|
check.Number(t, len(sources.sqlSources), 4)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
check.Number(t, len(sources.goSources), 0)
|
||||||
|
@ -123,7 +124,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
|
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
|
||||||
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
||||||
}
|
}
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 2)
|
check.Number(t, len(sources.sqlSources), 2)
|
||||||
check.Number(t, len(sources.goSources), 1)
|
check.Number(t, len(sources.goSources), 1)
|
||||||
|
@ -142,7 +143,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"001_foo.sql": sqlMapFile,
|
"001_foo.sql": sqlMapFile,
|
||||||
"01_bar.sql": sqlMapFile,
|
"01_bar.sql": sqlMapFile,
|
||||||
}
|
}
|
||||||
_, err := collectFilesystemSources(mapFS, false, nil)
|
_, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||||
})
|
})
|
||||||
|
@ -158,7 +159,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
f, err := fs.Sub(mapFS, dirpath)
|
f, err := fs.Sub(mapFS, dirpath)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
got, err := collectFilesystemSources(f, false, nil)
|
got, err := collectFilesystemSources(f, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(got.sqlSources), len(sqlSources))
|
check.Number(t, len(got.sqlSources), len(sqlSources))
|
||||||
check.Number(t, len(got.goSources), 0)
|
check.Number(t, len(got.goSources), 0)
|
||||||
|
@ -194,27 +195,21 @@ func TestMerge(t *testing.T) {
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Equal(t, len(sources.sqlSources), 1)
|
check.Equal(t, len(sources.sqlSources), 1)
|
||||||
check.Equal(t, len(sources.goSources), 2)
|
check.Equal(t, len(sources.goSources), 2)
|
||||||
src1 := sources.lookup(TypeSQL, 1)
|
|
||||||
check.Bool(t, src1 != nil, true)
|
|
||||||
src2 := sources.lookup(TypeGo, 2)
|
|
||||||
check.Bool(t, src2 != nil, true)
|
|
||||||
src3 := sources.lookup(TypeGo, 3)
|
|
||||||
check.Bool(t, src3 != nil, true)
|
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
migrations, err := merge(sources, map[int64]*goMigration{
|
registered := map[int64]*Migration{
|
||||||
2: newGoMigration("", nil, nil),
|
2: NewGoMigration(2, nil, nil),
|
||||||
3: newGoMigration("", nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
})
|
}
|
||||||
|
migrations, err := merge(sources, registered)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 3)
|
check.Number(t, len(migrations), 3)
|
||||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
|
||||||
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||||
})
|
})
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
_, err := merge(sources, nil)
|
_, err := merge(sources, nil)
|
||||||
|
@ -224,18 +219,16 @@ func TestMerge(t *testing.T) {
|
||||||
check.Contains(t, err.Error(), "00003_baz.go")
|
check.Contains(t, err.Error(), "00003_baz.go")
|
||||||
})
|
})
|
||||||
t.Run("unregistered_some", func(t *testing.T) {
|
t.Run("unregistered_some", func(t *testing.T) {
|
||||||
_, err := merge(sources, map[int64]*goMigration{
|
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
|
||||||
2: newGoMigration("", nil, nil),
|
|
||||||
})
|
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
||||||
check.Contains(t, err.Error(), "00003_baz.go")
|
check.Contains(t, err.Error(), "00003_baz.go")
|
||||||
})
|
})
|
||||||
t.Run("duplicate_sql", func(t *testing.T) {
|
t.Run("duplicate_sql", func(t *testing.T) {
|
||||||
_, err := merge(sources, map[int64]*goMigration{
|
_, err := merge(sources, map[int64]*Migration{
|
||||||
1: newGoMigration("", nil, nil), // duplicate. SQL already exists.
|
1: NewGoMigration(1, nil, nil), // duplicate. SQL already exists.
|
||||||
2: newGoMigration("", nil, nil),
|
2: NewGoMigration(2, nil, nil),
|
||||||
3: newGoMigration("", nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||||
|
@ -250,13 +243,13 @@ func TestMerge(t *testing.T) {
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
migrations, err := merge(sources, map[int64]*goMigration{
|
migrations, err := merge(sources, map[int64]*Migration{
|
||||||
3: newGoMigration("", nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
// 4 is missing
|
// 4 is missing
|
||||||
6: newGoMigration("", nil, nil),
|
6: NewGoMigration(6, nil, nil),
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 5)
|
check.Number(t, len(migrations), 5)
|
||||||
|
@ -274,20 +267,20 @@ func TestMerge(t *testing.T) {
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
migrations, err := merge(sources, map[int64]*goMigration{
|
migrations, err := merge(sources, map[int64]*Migration{
|
||||||
// This is the only Go file on disk.
|
// This is the only Go file on disk.
|
||||||
2: newGoMigration("", nil, nil),
|
2: NewGoMigration(2, nil, nil),
|
||||||
// These are not on disk. Explicitly registered.
|
// These are not on disk. Explicitly registered.
|
||||||
3: newGoMigration("", nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
6: newGoMigration("", nil, nil),
|
6: NewGoMigration(6, nil, nil),
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 4)
|
check.Number(t, len(migrations), 4)
|
||||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
assertMigration(t, migrations[1], newSource(TypeGo, "", 2))
|
||||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||||
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
|
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
|
||||||
})
|
})
|
||||||
|
@ -308,15 +301,15 @@ func TestCheckMissingMigrations(t *testing.T) {
|
||||||
{Version: 5},
|
{Version: 5},
|
||||||
{Version: 7}, // <-- database max version_id
|
{Version: 7}, // <-- database max version_id
|
||||||
}
|
}
|
||||||
fsMigrations := []*migration{
|
fsMigrations := []*Migration{
|
||||||
newMigrationVersion(1),
|
newSQLMigration(Source{Version: 1}),
|
||||||
newMigrationVersion(2), // missing migration
|
newSQLMigration(Source{Version: 2}), // missing migration
|
||||||
newMigrationVersion(3),
|
newSQLMigration(Source{Version: 3}),
|
||||||
newMigrationVersion(4),
|
newSQLMigration(Source{Version: 4}),
|
||||||
newMigrationVersion(5),
|
newSQLMigration(Source{Version: 5}),
|
||||||
newMigrationVersion(6), // missing migration
|
newSQLMigration(Source{Version: 6}), // missing migration
|
||||||
newMigrationVersion(7), // ----- database max version_id -----
|
newSQLMigration(Source{Version: 7}), // ----- database max version_id -----
|
||||||
newMigrationVersion(8), // new migration
|
newSQLMigration(Source{Version: 8}), // new migration
|
||||||
}
|
}
|
||||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||||
check.Number(t, len(got), 2)
|
check.Number(t, len(got), 2)
|
||||||
|
@ -334,9 +327,9 @@ func TestCheckMissingMigrations(t *testing.T) {
|
||||||
{Version: 5},
|
{Version: 5},
|
||||||
{Version: 2},
|
{Version: 2},
|
||||||
}
|
}
|
||||||
fsMigrations := []*migration{
|
fsMigrations := []*Migration{
|
||||||
newMigrationVersion(3), // new migration
|
NewGoMigration(3, nil, nil), // new migration
|
||||||
newMigrationVersion(4), // new migration
|
NewGoMigration(4, nil, nil), // new migration
|
||||||
}
|
}
|
||||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||||
check.Number(t, len(got), 2)
|
check.Number(t, len(got), 2)
|
||||||
|
@ -345,24 +338,19 @@ func TestCheckMissingMigrations(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMigrationVersion(version int64) *migration {
|
func assertMigration(t *testing.T, got *Migration, want Source) {
|
||||||
return &migration{
|
|
||||||
Source: Source{
|
|
||||||
Version: version,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertMigration(t *testing.T, got *migration, want Source) {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.Source, want)
|
check.Equal(t, got.Type, want.Type)
|
||||||
switch got.Source.Type {
|
check.Equal(t, got.Version, want.Version)
|
||||||
|
check.Equal(t, got.Source, want.Path)
|
||||||
|
switch got.Type {
|
||||||
case TypeGo:
|
case TypeGo:
|
||||||
check.Bool(t, got.Go != nil, true)
|
check.Bool(t, got.goUp != nil, true)
|
||||||
|
check.Bool(t, got.goDown != nil, true)
|
||||||
case TypeSQL:
|
case TypeSQL:
|
||||||
check.Bool(t, got.SQL == nil, true)
|
check.Bool(t, got.sql.Parsed, false)
|
||||||
default:
|
default:
|
||||||
t.Fatalf("unknown migration type: %s", got.Source.Type)
|
t.Fatalf("unknown migration type: %s", got.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package provider
|
package goose
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -16,8 +16,8 @@ var (
|
||||||
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
|
// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
|
||||||
ErrNoMigrations = errors.New("no migrations found")
|
ErrNoMigrations = errors.New("no migrations found")
|
||||||
|
|
||||||
// ErrNoNextVersion when the next migration version is not found.
|
// errInvalidVersion is returned when a migration version is invalid.
|
||||||
ErrNoNextVersion = errors.New("no next version found")
|
errInvalidVersion = errors.New("version must be greater than 0")
|
||||||
)
|
)
|
||||||
|
|
||||||
// PartialError is returned when a migration fails, but some migrations already got applied.
|
// PartialError is returned when a migration fails, but some migrations already got applied.
|
|
@ -1,8 +1,6 @@
|
||||||
package provider
|
package goose
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -12,12 +10,11 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// DefaultTablename is the default name of the database table used to track history of applied
|
// DefaultTablename is the default name of the database table used to track history of applied
|
||||||
// migrations. It can be overridden using the [WithTableName] option when creating a new
|
// migrations.
|
||||||
// provider.
|
|
||||||
DefaultTablename = "goose_db_version"
|
DefaultTablename = "goose_db_version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProviderOption is a configuration option for a goose provider.
|
// ProviderOption is a configuration option for a goose goose.
|
||||||
type ProviderOption interface {
|
type ProviderOption interface {
|
||||||
apply(*config) error
|
apply(*config) error
|
||||||
}
|
}
|
||||||
|
@ -84,84 +81,75 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithExcludes excludes the given file names from the list of migrations.
|
// WithExcludeNames excludes the given file name from the list of migrations. If called multiple
|
||||||
//
|
// times, the list of excludes is merged.
|
||||||
// If WithExcludes is called multiple times, the list of excludes is merged.
|
func WithExcludeNames(excludes []string) ProviderOption {
|
||||||
func WithExcludes(excludes []string) ProviderOption {
|
|
||||||
return configFunc(func(c *config) error {
|
return configFunc(func(c *config) error {
|
||||||
for _, name := range excludes {
|
for _, name := range excludes {
|
||||||
c.excludes[name] = true
|
if _, ok := c.excludePaths[name]; ok {
|
||||||
|
return fmt.Errorf("duplicate exclude file name: %s", name)
|
||||||
|
}
|
||||||
|
c.excludePaths[name] = true
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration].
|
// WithExcludeVersions excludes the given versions from the list of migrations. If called multiple
|
||||||
type GoMigrationFunc struct {
|
// times, the list of excludes is merged.
|
||||||
// One of the following must be set:
|
func WithExcludeVersions(versions []int64) ProviderOption {
|
||||||
Run func(context.Context, *sql.Tx) error
|
|
||||||
// -- OR --
|
|
||||||
RunNoTx func(context.Context, *sql.DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithGoMigration registers a Go migration with the given version.
|
|
||||||
//
|
|
||||||
// If WithGoMigration is called multiple times with the same version, an error is returned. Both up
|
|
||||||
// and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be
|
|
||||||
// set.
|
|
||||||
func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption {
|
|
||||||
return configFunc(func(c *config) error {
|
return configFunc(func(c *config) error {
|
||||||
|
for _, version := range versions {
|
||||||
if version < 1 {
|
if version < 1 {
|
||||||
return errors.New("version must be greater than zero")
|
return errInvalidVersion
|
||||||
}
|
}
|
||||||
if _, ok := c.registered[version]; ok {
|
if _, ok := c.excludeVersions[version]; ok {
|
||||||
return fmt.Errorf("go migration with version %d already registered", version)
|
return fmt.Errorf("duplicate excludes version: %d", version)
|
||||||
}
|
}
|
||||||
// Allow nil up/down functions. This enables users to apply "no-op" migrations, while
|
c.excludeVersions[version] = true
|
||||||
// versioning them.
|
|
||||||
if up != nil {
|
|
||||||
if up.Run == nil && up.RunNoTx == nil {
|
|
||||||
return fmt.Errorf("go migration with version %d must have an up function", version)
|
|
||||||
}
|
|
||||||
if up.Run != nil && up.RunNoTx != nil {
|
|
||||||
return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if down != nil {
|
|
||||||
if down.Run == nil && down.RunNoTx == nil {
|
|
||||||
return fmt.Errorf("go migration with version %d must have a down function", version)
|
|
||||||
}
|
|
||||||
if down.Run != nil && down.RunNoTx != nil {
|
|
||||||
return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.registered[version] = &goMigration{
|
|
||||||
up: up,
|
|
||||||
down: down,
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default,
|
// WithGoMigrations registers Go migrations with the provider. If a Go migration with the same
|
||||||
|
// version has already been registered, an error will be returned.
|
||||||
|
//
|
||||||
|
// Go migrations must be constructed using the [NewGoMigration] function.
|
||||||
|
func WithGoMigrations(migrations ...*Migration) ProviderOption {
|
||||||
|
return configFunc(func(c *config) error {
|
||||||
|
for _, m := range migrations {
|
||||||
|
if _, ok := c.registered[m.Version]; ok {
|
||||||
|
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||||
|
}
|
||||||
|
if err := checkGoMigration(m); err != nil {
|
||||||
|
return fmt.Errorf("invalid go migration: %w", err)
|
||||||
|
}
|
||||||
|
c.registered[m.Version] = m
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowOutofOrder allows the provider to apply missing (out-of-order) migrations. By default,
|
||||||
// goose will raise an error if it encounters a missing migration.
|
// goose will raise an error if it encounters a missing migration.
|
||||||
//
|
//
|
||||||
// Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true,
|
// For example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is
|
||||||
// then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of
|
// true, then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order
|
||||||
// applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed
|
// of applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first,
|
||||||
// by new migrations.
|
// followed by new migrations.
|
||||||
func WithAllowedMissing(b bool) ProviderOption {
|
func WithAllowOutofOrder(b bool) ProviderOption {
|
||||||
return configFunc(func(c *config) error {
|
return configFunc(func(c *config) error {
|
||||||
c.allowMissing = b
|
c.allowMissing = b
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations
|
// WithDisableVersioning disables versioning. Disabling versioning allows applying migrations
|
||||||
// without tracking the versions in the database schema table. Useful for tests, seeding a database
|
// without tracking the versions in the database schema table. Useful for tests, seeding a database
|
||||||
// or running ad-hoc queries. By default, goose will track all versions in the database schema
|
// or running ad-hoc queries. By default, goose will track all versions in the database schema
|
||||||
// table.
|
// table.
|
||||||
func WithDisabledVersioning(b bool) ProviderOption {
|
func WithDisableVersioning(b bool) ProviderOption {
|
||||||
return configFunc(func(c *config) error {
|
return configFunc(func(c *config) error {
|
||||||
c.disableVersioning = b
|
c.disableVersioning = b
|
||||||
return nil
|
return nil
|
||||||
|
@ -172,11 +160,12 @@ type config struct {
|
||||||
store database.Store
|
store database.Store
|
||||||
|
|
||||||
verbose bool
|
verbose bool
|
||||||
excludes map[string]bool
|
excludePaths map[string]bool
|
||||||
|
excludeVersions map[int64]bool
|
||||||
|
|
||||||
// Go migrations registered by the user. These will be merged/resolved with migrations from the
|
// Go migrations registered by the user. These will be merged/resolved against the globally
|
||||||
// filesystem and init() functions.
|
// registered migrations.
|
||||||
registered map[int64]*goMigration
|
registered map[int64]*Migration
|
||||||
|
|
||||||
// Locking options
|
// Locking options
|
||||||
lockEnabled bool
|
lockEnabled bool
|
|
@ -1,4 +1,4 @@
|
||||||
package provider_test
|
package goose_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
@ -6,9 +6,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/database"
|
"github.com/pressly/goose/v3/database"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/pressly/goose/v3/internal/check"
|
||||||
"github.com/pressly/goose/v3/internal/provider"
|
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,45 +24,42 @@ func TestNewProvider(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Run("invalid", func(t *testing.T) {
|
t.Run("invalid", func(t *testing.T) {
|
||||||
// Empty dialect not allowed
|
// Empty dialect not allowed
|
||||||
_, err = provider.NewProvider("", db, fsys)
|
_, err = goose.NewProvider("", db, fsys)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
// Invalid dialect not allowed
|
// Invalid dialect not allowed
|
||||||
_, err = provider.NewProvider("unknown-dialect", db, fsys)
|
_, err = goose.NewProvider("unknown-dialect", db, fsys)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
// Nil db not allowed
|
// Nil db not allowed
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, nil, fsys)
|
_, err = goose.NewProvider(database.DialectSQLite3, nil, fsys)
|
||||||
check.HasError(t, err)
|
|
||||||
// Nil fsys not allowed
|
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil)
|
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
// Nil store not allowed
|
// Nil store not allowed
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(nil))
|
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil))
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
// Cannot set both dialect and store
|
// Cannot set both dialect and store
|
||||||
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithStore(store))
|
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(store))
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
// Multiple stores not allowed
|
// Multiple stores not allowed
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, nil,
|
_, err = goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||||
provider.WithStore(store),
|
goose.WithStore(store),
|
||||||
provider.WithStore(store),
|
goose.WithStore(store),
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
})
|
})
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
// Valid dialect, db, and fsys allowed
|
// Valid dialect, db, and fsys allowed
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys)
|
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
// Valid dialect, db, fsys, and verbose allowed
|
// Valid dialect, db, fsys, and verbose allowed
|
||||||
_, err = provider.NewProvider(database.DialectSQLite3, db, fsys,
|
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||||
provider.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
// Custom store allowed
|
// Custom store allowed
|
||||||
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
_, err = provider.NewProvider("", db, nil, provider.WithStore(store))
|
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
|
@ -0,0 +1,460 @@
|
||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3/database"
|
||||||
|
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||||
|
"go.uber.org/multierr"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMissingZeroVersion = errors.New("missing zero version migration")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Provider) resolveUpMigrations(
|
||||||
|
dbVersions []*database.ListMigrationsResult,
|
||||||
|
version int64,
|
||||||
|
) ([]*Migration, error) {
|
||||||
|
var apply []*Migration
|
||||||
|
var dbMaxVersion int64
|
||||||
|
// dbAppliedVersions is a map of all applied migrations in the database.
|
||||||
|
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
|
||||||
|
for _, m := range dbVersions {
|
||||||
|
dbAppliedVersions[m.Version] = true
|
||||||
|
if m.Version > dbMaxVersion {
|
||||||
|
dbMaxVersion = m.Version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
missingMigrations := checkMissingMigrations(dbVersions, p.migrations)
|
||||||
|
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
|
||||||
|
// migrations entirely. At the moment this is not supported, but leaving this comment because
|
||||||
|
// that's where that logic would be handled.
|
||||||
|
//
|
||||||
|
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not
|
||||||
|
// sure if this is a common use case, but it's possible.
|
||||||
|
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||||
|
var collected []string
|
||||||
|
for _, v := range missingMigrations {
|
||||||
|
collected = append(collected, fmt.Sprintf("%d", v.versionID))
|
||||||
|
}
|
||||||
|
msg := "migration"
|
||||||
|
if len(collected) > 1 {
|
||||||
|
msg += "s"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
|
||||||
|
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for _, v := range missingMigrations {
|
||||||
|
m, err := p.getMigration(v.versionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
apply = append(apply, m)
|
||||||
|
}
|
||||||
|
// filter all migrations with a version greater than the supplied version (min) and less than or
|
||||||
|
// equal to the requested version (max). Skip any migrations that have already been applied.
|
||||||
|
for _, m := range p.migrations {
|
||||||
|
if dbAppliedVersions[m.Version] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.Version > dbMaxVersion && m.Version <= version {
|
||||||
|
apply = append(apply, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return apply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) prepareMigration(fsys fs.FS, m *Migration, direction bool) error {
|
||||||
|
switch m.Type {
|
||||||
|
case TypeGo:
|
||||||
|
if m.goUp.Mode == 0 {
|
||||||
|
return errors.New("go up migration mode is not set")
|
||||||
|
}
|
||||||
|
if m.goDown.Mode == 0 {
|
||||||
|
return errors.New("go down migration mode is not set")
|
||||||
|
}
|
||||||
|
var useTx bool
|
||||||
|
if direction {
|
||||||
|
useTx = m.goUp.Mode == TransactionEnabled
|
||||||
|
} else {
|
||||||
|
useTx = m.goDown.Mode == TransactionEnabled
|
||||||
|
}
|
||||||
|
// bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB,
|
||||||
|
// but are locking the database with *sql.Conn. If the caller sets max open connections to
|
||||||
|
// 1, then this will deadlock because the Go migration will try to acquire a connection from
|
||||||
|
// the pool, but the pool is exhausted because the lock is held.
|
||||||
|
//
|
||||||
|
// A potential solution is to expose a third Go register function *sql.Conn. Or continue to
|
||||||
|
// use *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is
|
||||||
|
// a bit of an edge case. For now, we guard against this scenario by checking the max open
|
||||||
|
// connections and returning an error.
|
||||||
|
if p.cfg.lockEnabled && p.cfg.sessionLocker != nil && p.db.Stats().MaxOpenConnections == 1 {
|
||||||
|
if !useTx {
|
||||||
|
return errors.New("potential deadlock detected: cannot run Go migration without a transaction when max open connections set to 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case TypeSQL:
|
||||||
|
if m.sql.Parsed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.sql.Parsed = true
|
||||||
|
m.sql.UseTx = parsed.UseTx
|
||||||
|
m.sql.Up, m.sql.Down = parsed.Up, parsed.Down
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid migration type: %+v", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMigrations runs migrations sequentially in the given direction. If the migrations list is
|
||||||
|
// empty, return nil without error.
|
||||||
|
func (p *Provider) runMigrations(
|
||||||
|
ctx context.Context,
|
||||||
|
conn *sql.Conn,
|
||||||
|
migrations []*Migration,
|
||||||
|
direction sqlparser.Direction,
|
||||||
|
byOne bool,
|
||||||
|
) ([]*MigrationResult, error) {
|
||||||
|
if len(migrations) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
apply := migrations
|
||||||
|
if byOne {
|
||||||
|
apply = migrations[:1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL migrations are lazily parsed in both directions. This is done before attempting to run
|
||||||
|
// any migrations to catch errors early and prevent leaving the database in an incomplete state.
|
||||||
|
|
||||||
|
for _, m := range apply {
|
||||||
|
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// feat(mf): If we decide to add support for advisory locks at the transaction level, this may
|
||||||
|
// be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe
|
||||||
|
// to run in a transaction.
|
||||||
|
|
||||||
|
// feat(mf): this is where we can (optionally) group multiple migrations to be run in a single
|
||||||
|
// transaction. The default is to apply each migration sequentially on its own. See the
|
||||||
|
// following issues for more details:
|
||||||
|
// - https://github.com/pressly/goose/issues/485
|
||||||
|
// - https://github.com/pressly/goose/issues/222
|
||||||
|
//
|
||||||
|
// Be careful, we can't use a single transaction for all migrations because some may be marked
|
||||||
|
// as not using a transaction.
|
||||||
|
|
||||||
|
var results []*MigrationResult
|
||||||
|
for _, m := range apply {
|
||||||
|
current := &MigrationResult{
|
||||||
|
Source: &Source{
|
||||||
|
Type: m.Type,
|
||||||
|
Path: m.Source,
|
||||||
|
Version: m.Version,
|
||||||
|
},
|
||||||
|
Direction: direction.String(),
|
||||||
|
Empty: isEmpty(m, direction.ToBool()),
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
if err := p.runIndividually(ctx, conn, m, direction.ToBool()); err != nil {
|
||||||
|
// TODO(mf): we should also return the pending migrations here, the remaining items in
|
||||||
|
// the apply slice.
|
||||||
|
current.Error = err
|
||||||
|
current.Duration = time.Since(start)
|
||||||
|
return nil, &PartialError{
|
||||||
|
Applied: results,
|
||||||
|
Failed: current,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current.Duration = time.Since(start)
|
||||||
|
results = append(results, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) runIndividually(
|
||||||
|
ctx context.Context,
|
||||||
|
conn *sql.Conn,
|
||||||
|
m *Migration,
|
||||||
|
direction bool,
|
||||||
|
) error {
|
||||||
|
useTx, err := useTx(m, direction)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if useTx {
|
||||||
|
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||||
|
if err := runMigration(ctx, tx, m, direction); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.maybeInsertOrDelete(ctx, tx, m.Version, direction)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
switch m.Type {
|
||||||
|
case TypeGo:
|
||||||
|
// Note, we are using *sql.DB instead of *sql.Conn because it's the Go migration contract.
|
||||||
|
// This may be a deadlock scenario if max open connections is set to 1 AND a lock is
|
||||||
|
// acquired on the database. In this case, the migration will block forever unable to
|
||||||
|
// acquire a connection from the pool.
|
||||||
|
//
|
||||||
|
// For now, we guard against this scenario by checking the max open connections and
|
||||||
|
// returning an error in the prepareMigration function.
|
||||||
|
if err := runMigration(ctx, p.db, m, direction); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.maybeInsertOrDelete(ctx, p.db, m.Version, direction)
|
||||||
|
case TypeSQL:
|
||||||
|
if err := runMigration(ctx, conn, m, direction); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.maybeInsertOrDelete(ctx, conn, m.Version, direction)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to run individual migration: neither sql or go: %v", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) maybeInsertOrDelete(
|
||||||
|
ctx context.Context,
|
||||||
|
db database.DBTxConn,
|
||||||
|
version int64,
|
||||||
|
direction bool,
|
||||||
|
) error {
|
||||||
|
// If versioning is disabled, we don't need to insert or delete the migration version.
|
||||||
|
if p.cfg.disableVersioning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if direction {
|
||||||
|
return p.store.Insert(ctx, db, database.InsertRequest{Version: version})
|
||||||
|
}
|
||||||
|
return p.store.Delete(ctx, db, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beginTx begins a transaction and runs the given function. If the function returns an error, the
|
||||||
|
// transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
|
func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (retErr error) {
|
||||||
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if retErr != nil {
|
||||||
|
retErr = multierr.Append(retErr, tx.Rollback())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := fn(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
conn, err := p.db.Conn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// cleanup is a function that cleans up the connection, and optionally, the session lock.
|
||||||
|
cleanup := func() error {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return conn.Close()
|
||||||
|
}
|
||||||
|
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
|
||||||
|
if err := l.SessionLock(ctx, conn); err != nil {
|
||||||
|
return nil, nil, multierr.Append(err, cleanup())
|
||||||
|
}
|
||||||
|
// A lock was acquired, so we need to unlock the session when we're done. This is done by
|
||||||
|
// returning a cleanup function that unlocks the session and closes the connection.
|
||||||
|
cleanup = func() error {
|
||||||
|
p.mu.Unlock()
|
||||||
|
// Use a detached context to unlock the session. This is because the context passed to
|
||||||
|
// SessionLock may have been canceled, and we don't want to cancel the unlock.
|
||||||
|
//
|
||||||
|
// TODO(mf): use context.WithoutCancel added in go1.21
|
||||||
|
detachedCtx := context.Background()
|
||||||
|
return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
|
||||||
|
// need the version table because no versions are being recorded.
|
||||||
|
if !p.cfg.disableVersioning {
|
||||||
|
if err := p.ensureVersionTable(ctx, conn); err != nil {
|
||||||
|
return nil, nil, multierr.Append(err, cleanup())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
||||||
|
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
||||||
|
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
||||||
|
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||||
|
if err == nil && res != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||||
|
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if p.cfg.disableVersioning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type missingMigration struct {
|
||||||
|
versionID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
|
||||||
|
// migration is one that has a version less than the max version in the database.
|
||||||
|
func checkMissingMigrations(
|
||||||
|
dbMigrations []*database.ListMigrationsResult,
|
||||||
|
fsMigrations []*Migration,
|
||||||
|
) []missingMigration {
|
||||||
|
existing := make(map[int64]bool)
|
||||||
|
var dbMaxVersion int64
|
||||||
|
for _, m := range dbMigrations {
|
||||||
|
existing[m.Version] = true
|
||||||
|
if m.Version > dbMaxVersion {
|
||||||
|
dbMaxVersion = m.Version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var missing []missingMigration
|
||||||
|
for _, m := range fsMigrations {
|
||||||
|
version := m.Version
|
||||||
|
if !existing[version] && version < dbMaxVersion {
|
||||||
|
missing = append(missing, missingMigration{
|
||||||
|
versionID: version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(missing, func(i, j int) bool {
|
||||||
|
return missing[i].versionID < missing[j].versionID
|
||||||
|
})
|
||||||
|
return missing
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMigration returns the migration for the given version. If no migration is found, then
|
||||||
|
// ErrVersionNotFound is returned.
|
||||||
|
func (p *Provider) getMigration(version int64) (*Migration, error) {
|
||||||
|
for _, m := range p.migrations {
|
||||||
|
if m.Version == version {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrVersionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// useTx is a helper function that returns true if the migration should be run in a transaction. It
|
||||||
|
// must only be called after the migration has been parsed and initialized.
|
||||||
|
func useTx(m *Migration, direction bool) (bool, error) {
|
||||||
|
switch m.Type {
|
||||||
|
case TypeGo:
|
||||||
|
if m.goUp.Mode == 0 || m.goDown.Mode == 0 {
|
||||||
|
return false, fmt.Errorf("go migrations must have a mode set")
|
||||||
|
}
|
||||||
|
if direction {
|
||||||
|
return m.goUp.Mode == TransactionEnabled, nil
|
||||||
|
}
|
||||||
|
return m.goDown.Mode == TransactionEnabled, nil
|
||||||
|
case TypeSQL:
|
||||||
|
if !m.sql.Parsed {
|
||||||
|
return false, fmt.Errorf("sql migrations must be parsed")
|
||||||
|
}
|
||||||
|
return m.sql.UseTx, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("use tx: invalid migration type: %q", m.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEmpty is a helper function that returns true if the migration has no functions or no statements
|
||||||
|
// to execute. It must only be called after the migration has been parsed and initialized.
|
||||||
|
func isEmpty(m *Migration, direction bool) bool {
|
||||||
|
switch m.Type {
|
||||||
|
case TypeGo:
|
||||||
|
if direction {
|
||||||
|
return m.goUp.RunTx == nil && m.goUp.RunDB == nil
|
||||||
|
}
|
||||||
|
return m.goDown.RunTx == nil && m.goDown.RunDB == nil
|
||||||
|
case TypeSQL:
|
||||||
|
if direction {
|
||||||
|
return len(m.sql.Up) == 0
|
||||||
|
}
|
||||||
|
return len(m.sql.Down) == 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMigration is a helper function that runs the migration in the given direction. It must only be
|
||||||
|
// called after the migration has been parsed and initialized.
|
||||||
|
func runMigration(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||||
|
switch m.Type {
|
||||||
|
case TypeGo:
|
||||||
|
return runGo(ctx, db, m, direction)
|
||||||
|
case TypeSQL:
|
||||||
|
return runSQL(ctx, db, m, direction)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid migration type: %q", m.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runGo is a helper function that runs the given Go functions in the given direction. It must only
|
||||||
|
// be called after the migration has been initialized.
|
||||||
|
func runGo(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||||
|
switch db := db.(type) {
|
||||||
|
case *sql.Conn:
|
||||||
|
return fmt.Errorf("go migrations are not supported with *sql.Conn")
|
||||||
|
case *sql.DB:
|
||||||
|
if direction && m.goUp.RunDB != nil {
|
||||||
|
return m.goUp.RunDB(ctx, db)
|
||||||
|
}
|
||||||
|
if !direction && m.goDown.RunDB != nil {
|
||||||
|
return m.goDown.RunDB(ctx, db)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *sql.Tx:
|
||||||
|
if direction && m.goUp.RunTx != nil {
|
||||||
|
return m.goUp.RunTx(ctx, db)
|
||||||
|
}
|
||||||
|
if !direction && m.goDown.RunTx != nil {
|
||||||
|
return m.goDown.RunTx(ctx, db)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid database connection type: %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSQL is a helper function that runs the given SQL statements in the given direction. It must
|
||||||
|
// only be called after the migration has been parsed.
|
||||||
|
func runSQL(ctx context.Context, db database.DBTxConn, m *Migration, direction bool) error {
|
||||||
|
if !m.sql.Parsed {
|
||||||
|
return fmt.Errorf("sql migrations must be parsed")
|
||||||
|
}
|
||||||
|
var statements []string
|
||||||
|
if direction {
|
||||||
|
statements = m.sql.Up
|
||||||
|
} else {
|
||||||
|
statements = m.sql.Down
|
||||||
|
}
|
||||||
|
for _, stmt := range statements {
|
||||||
|
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package provider_test
|
package goose_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -16,9 +16,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/database"
|
"github.com/pressly/goose/v3/database"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/pressly/goose/v3/internal/check"
|
||||||
"github.com/pressly/goose/v3/internal/provider"
|
|
||||||
"github.com/pressly/goose/v3/internal/testdb"
|
"github.com/pressly/goose/v3/internal/testdb"
|
||||||
"github.com/pressly/goose/v3/lock"
|
"github.com/pressly/goose/v3/lock"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
@ -45,22 +45,22 @@ func TestProviderRun(t *testing.T) {
|
||||||
p, _ := newProviderWithDB(t)
|
p, _ := newProviderWithDB(t)
|
||||||
_, err := p.ApplyVersion(context.Background(), 999, true)
|
_, err := p.ApplyVersion(context.Background(), 999, true)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
|
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
||||||
_, err = p.ApplyVersion(context.Background(), 999, false)
|
_, err = p.ApplyVersion(context.Background(), 999, false)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true)
|
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
||||||
})
|
})
|
||||||
t.Run("run_zero", func(t *testing.T) {
|
t.Run("run_zero", func(t *testing.T) {
|
||||||
p, _ := newProviderWithDB(t)
|
p, _ := newProviderWithDB(t)
|
||||||
_, err := p.UpTo(context.Background(), 0)
|
_, err := p.UpTo(context.Background(), 0)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
check.Equal(t, err.Error(), "version must be greater than 0")
|
||||||
_, err = p.DownTo(context.Background(), -1)
|
_, err = p.DownTo(context.Background(), -1)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
||||||
_, err = p.ApplyVersion(context.Background(), 0, true)
|
_, err = p.ApplyVersion(context.Background(), 0, true)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
|
check.Equal(t, err.Error(), "version must be greater than 0")
|
||||||
})
|
})
|
||||||
t.Run("up_and_down_all", func(t *testing.T) {
|
t.Run("up_and_down_all", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -72,30 +72,30 @@ func TestProviderRun(t *testing.T) {
|
||||||
check.Number(t, len(sources), numCount)
|
check.Number(t, len(sources), numCount)
|
||||||
// Ensure only SQL migrations are returned
|
// Ensure only SQL migrations are returned
|
||||||
for _, s := range sources {
|
for _, s := range sources {
|
||||||
check.Equal(t, s.Type, provider.TypeSQL)
|
check.Equal(t, s.Type, goose.TypeSQL)
|
||||||
}
|
}
|
||||||
// Test Up
|
// Test Up
|
||||||
res, err := p.Up(ctx)
|
res, err := p.Up(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(res), numCount)
|
check.Number(t, len(res), numCount)
|
||||||
assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||||
assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
||||||
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false)
|
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "up", false)
|
||||||
assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false)
|
assertResult(t, res[4], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "up", false)
|
||||||
assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true)
|
assertResult(t, res[5], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "up", true)
|
||||||
assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
||||||
// Test Down
|
// Test Down
|
||||||
res, err = p.DownTo(ctx, 0)
|
res, err = p.DownTo(ctx, 0)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(res), numCount)
|
check.Number(t, len(res), numCount)
|
||||||
assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
||||||
assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
||||||
assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
||||||
assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false)
|
assertResult(t, res[3], newSource(goose.TypeSQL, "00004_insert_data.sql", 4), "down", false)
|
||||||
assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false)
|
assertResult(t, res[4], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "down", false)
|
||||||
assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false)
|
assertResult(t, res[5], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "down", false)
|
||||||
assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false)
|
assertResult(t, res[6], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "down", false)
|
||||||
})
|
})
|
||||||
t.Run("up_and_down_by_one", func(t *testing.T) {
|
t.Run("up_and_down_by_one", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -107,8 +107,8 @@ func TestProviderRun(t *testing.T) {
|
||||||
res, err := p.UpByOne(ctx)
|
res, err := p.UpByOne(ctx)
|
||||||
counter++
|
counter++
|
||||||
if counter > maxVersion {
|
if counter > maxVersion {
|
||||||
if !errors.Is(err, provider.ErrNoNextVersion) {
|
if !errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
|
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -126,8 +126,8 @@ func TestProviderRun(t *testing.T) {
|
||||||
res, err := p.Down(ctx)
|
res, err := p.Down(ctx)
|
||||||
counter++
|
counter++
|
||||||
if counter > maxVersion {
|
if counter > maxVersion {
|
||||||
if !errors.Is(err, provider.ErrNoNextVersion) {
|
if !errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion)
|
t.Fatalf("incorrect error: got:%v want:%v", err, goose.ErrNoNextVersion)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -149,14 +149,14 @@ func TestProviderRun(t *testing.T) {
|
||||||
results, err := p.UpTo(ctx, upToVersion)
|
results, err := p.UpTo(ctx, upToVersion)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(results), upToVersion)
|
check.Number(t, len(results), upToVersion)
|
||||||
assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||||
// Fetch the goose version from DB
|
// Fetch the goose version from DB
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, currentVersion, upToVersion)
|
check.Number(t, currentVersion, upToVersion)
|
||||||
// Validate the version actually matches what goose claims it is
|
// Validate the version actually matches what goose claims it is
|
||||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, gotVersion, upToVersion)
|
check.Number(t, gotVersion, upToVersion)
|
||||||
})
|
})
|
||||||
|
@ -197,7 +197,7 @@ func TestProviderRun(t *testing.T) {
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version)
|
check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version)
|
||||||
// Validate the db migration version actually matches what goose claims it is
|
// Validate the db migration version actually matches what goose claims it is
|
||||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, gotVersion, currentVersion)
|
check.Number(t, gotVersion, currentVersion)
|
||||||
tables, err := getTableNames(db)
|
tables, err := getTableNames(db)
|
||||||
|
@ -213,13 +213,13 @@ func TestProviderRun(t *testing.T) {
|
||||||
downResult, err := p.DownTo(ctx, 0)
|
downResult, err := p.DownTo(ctx, 0)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(downResult), len(sources))
|
check.Number(t, len(downResult), len(sources))
|
||||||
gotVersion, err := getMaxVersionID(db, provider.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, gotVersion, 0)
|
check.Number(t, gotVersion, 0)
|
||||||
// Should only be left with a single table, the default goose table
|
// Should only be left with a single table, the default goose table
|
||||||
tables, err := getTableNames(db)
|
tables, err := getTableNames(db)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
knownTables := []string{provider.DefaultTablename, "sqlite_sequence"}
|
knownTables := []string{goose.DefaultTablename, "sqlite_sequence"}
|
||||||
if !reflect.DeepEqual(tables, knownTables) {
|
if !reflect.DeepEqual(tables, knownTables) {
|
||||||
t.Logf("got tables: %v", tables)
|
t.Logf("got tables: %v", tables)
|
||||||
t.Logf("known tables: %v", knownTables)
|
t.Logf("known tables: %v", knownTables)
|
||||||
|
@ -261,7 +261,7 @@ func TestProviderRun(t *testing.T) {
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
_, err = p.ApplyVersion(ctx, 1, true)
|
_, err = p.ApplyVersion(ctx, 1, true)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true)
|
check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true)
|
||||||
check.Contains(t, err.Error(), "version 1: already applied")
|
check.Contains(t, err.Error(), "version 1: already applied")
|
||||||
})
|
})
|
||||||
t.Run("status", func(t *testing.T) {
|
t.Run("status", func(t *testing.T) {
|
||||||
|
@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) {
|
||||||
status, err := p.Status(ctx)
|
status, err := p.Status(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(status), numCount)
|
check.Number(t, len(status), numCount)
|
||||||
assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true)
|
assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true)
|
||||||
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true)
|
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true)
|
||||||
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true)
|
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true)
|
||||||
assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true)
|
assertStatus(t, status[3], goose.StatePending, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), true)
|
||||||
assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true)
|
assertStatus(t, status[4], goose.StatePending, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), true)
|
||||||
assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true)
|
assertStatus(t, status[5], goose.StatePending, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), true)
|
||||||
assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
||||||
// Apply all migrations
|
// Apply all migrations
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
status, err = p.Status(ctx)
|
status, err = p.Status(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(status), numCount)
|
check.Number(t, len(status), numCount)
|
||||||
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
|
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||||
assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false)
|
assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false)
|
||||||
assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false)
|
assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false)
|
||||||
assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false)
|
assertStatus(t, status[3], goose.StateApplied, newSource(goose.TypeSQL, "00004_insert_data.sql", 4), false)
|
||||||
assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false)
|
assertStatus(t, status[4], goose.StateApplied, newSource(goose.TypeSQL, "00005_posts_view.sql", 5), false)
|
||||||
assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false)
|
assertStatus(t, status[5], goose.StateApplied, newSource(goose.TypeSQL, "00006_empty_up.sql", 6), false)
|
||||||
assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false)
|
assertStatus(t, status[6], goose.StateApplied, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), false)
|
||||||
})
|
})
|
||||||
t.Run("tx_partial_errors", func(t *testing.T) {
|
t.Run("tx_partial_errors", func(t *testing.T) {
|
||||||
countOwners := func(db *sql.DB) (int, error) {
|
countOwners := func(db *sql.DB) (int, error) {
|
||||||
|
@ -321,22 +321,22 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2');
|
||||||
INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||||
`),
|
`),
|
||||||
}
|
}
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS)
|
p, err := goose.NewProvider(database.DialectSQLite3, db, mapFS)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
|
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
|
||||||
var expected *provider.PartialError
|
var expected *goose.PartialError
|
||||||
check.Bool(t, errors.As(err, &expected), true)
|
check.Bool(t, errors.As(err, &expected), true)
|
||||||
// Check Err field
|
// Check Err field
|
||||||
check.Bool(t, expected.Err != nil, true)
|
check.Bool(t, expected.Err != nil, true)
|
||||||
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||||
// Check Results field
|
// Check Results field
|
||||||
check.Number(t, len(expected.Applied), 1)
|
check.Number(t, len(expected.Applied), 1)
|
||||||
assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
// Check Failed field
|
// Check Failed field
|
||||||
check.Bool(t, expected.Failed != nil, true)
|
check.Bool(t, expected.Failed != nil, true)
|
||||||
assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2)
|
assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2)
|
||||||
check.Bool(t, expected.Failed.Empty, false)
|
check.Bool(t, expected.Failed.Empty, false)
|
||||||
check.Bool(t, expected.Failed.Error != nil, true)
|
check.Bool(t, expected.Failed.Error != nil, true)
|
||||||
check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
|
check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||||
|
@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||||
status, err := p.Status(ctx)
|
status, err := p.Status(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(status), 3)
|
check.Number(t, len(status), 3)
|
||||||
assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false)
|
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||||
assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true)
|
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true)
|
||||||
assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true)
|
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -415,7 +415,7 @@ func TestConcurrentProvider(t *testing.T) {
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, currentVersion, maxVersion)
|
check.Number(t, currentVersion, maxVersion)
|
||||||
|
|
||||||
ch := make(chan []*provider.MigrationResult)
|
ch := make(chan []*goose.MigrationResult)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < maxVersion; i++ {
|
for i := 0; i < maxVersion; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -435,8 +435,8 @@ func TestConcurrentProvider(t *testing.T) {
|
||||||
close(ch)
|
close(ch)
|
||||||
}()
|
}()
|
||||||
var (
|
var (
|
||||||
valid [][]*provider.MigrationResult
|
valid [][]*goose.MigrationResult
|
||||||
empty [][]*provider.MigrationResult
|
empty [][]*goose.MigrationResult
|
||||||
)
|
)
|
||||||
for results := range ch {
|
for results := range ch {
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
|
@ -486,9 +486,9 @@ func TestNoVersioning(t *testing.T) {
|
||||||
// These are owners created by migration files.
|
// These are owners created by migration files.
|
||||||
wantOwnerCount = 4
|
wantOwnerCount = 4
|
||||||
)
|
)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
|
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||||
provider.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
provider.WithDisabledVersioning(false), // This is the default.
|
goose.WithDisableVersioning(false), // This is the default.
|
||||||
)
|
)
|
||||||
check.Number(t, len(p.ListSources()), 3)
|
check.Number(t, len(p.ListSources()), 3)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
@ -499,9 +499,9 @@ func TestNoVersioning(t *testing.T) {
|
||||||
check.Number(t, baseVersion, 3)
|
check.Number(t, baseVersion, 3)
|
||||||
t.Run("seed-up-down-to-zero", func(t *testing.T) {
|
t.Run("seed-up-down-to-zero", func(t *testing.T) {
|
||||||
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
|
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, fsys,
|
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
|
||||||
provider.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
provider.WithDisabledVersioning(true), // Provider with no versioning.
|
goose.WithDisableVersioning(true), // Provider with no versioning.
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
check.Number(t, len(p.ListSources()), 2)
|
||||||
|
@ -552,8 +552,8 @@ func TestAllowMissing(t *testing.T) {
|
||||||
|
|
||||||
t.Run("missing_now_allowed", func(t *testing.T) {
|
t.Run("missing_now_allowed", func(t *testing.T) {
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
|
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||||
provider.WithAllowedMissing(false),
|
goose.WithAllowOutofOrder(false),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
|
||||||
|
@ -607,8 +607,8 @@ func TestAllowMissing(t *testing.T) {
|
||||||
|
|
||||||
t.Run("missing_allowed", func(t *testing.T) {
|
t.Run("missing_allowed", func(t *testing.T) {
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(),
|
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
|
||||||
provider.WithAllowedMissing(true),
|
goose.WithAllowOutofOrder(true),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
|
||||||
|
@ -640,7 +640,7 @@ func TestAllowMissing(t *testing.T) {
|
||||||
check.Bool(t, upResult != nil, true)
|
check.Bool(t, upResult != nil, true)
|
||||||
check.Number(t, upResult.Source.Version, 6)
|
check.Number(t, upResult.Source.Version, 6)
|
||||||
|
|
||||||
count, err := getGooseVersionCount(db, provider.DefaultTablename)
|
count, err := getGooseVersionCount(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
check.Number(t, count, 6)
|
check.Number(t, count, 6)
|
||||||
current, err := p.GetDBVersion(ctx)
|
current, err := p.GetDBVersion(ctx)
|
||||||
|
@ -676,7 +676,7 @@ func TestAllowMissing(t *testing.T) {
|
||||||
testDownAndVersion(1, 1)
|
testDownAndVersion(1, 1)
|
||||||
_, err = p.Down(ctx)
|
_, err = p.Down(ctx)
|
||||||
check.HasError(t, err)
|
check.HasError(t, err)
|
||||||
check.Bool(t, errors.Is(err, provider.ErrNoNextVersion), true)
|
check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -691,6 +691,7 @@ func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoOnly(t *testing.T) {
|
func TestGoOnly(t *testing.T) {
|
||||||
|
t.Cleanup(goose.ResetGlobalMigrations)
|
||||||
// Not parallel because each subtest modifies global state.
|
// Not parallel because each subtest modifies global state.
|
||||||
|
|
||||||
countUser := func(db *sql.DB) int {
|
countUser := func(db *sql.DB) int {
|
||||||
|
@ -703,99 +704,109 @@ func TestGoOnly(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with_tx", func(t *testing.T) {
|
t.Run("with_tx", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
register := []*provider.MigrationCopy{
|
register := []*goose.Migration{
|
||||||
{
|
goose.NewGoMigration(
|
||||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
1,
|
||||||
UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
&goose.GoFunc{RunTx: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)")},
|
||||||
DownFnContext: newTxFn("DROP TABLE users"),
|
&goose.GoFunc{RunTx: newTxFn("DROP TABLE users")},
|
||||||
},
|
),
|
||||||
}
|
}
|
||||||
err := provider.SetGlobalGoMigrations(register)
|
err := goose.SetGlobalMigrations(register...)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
t.Cleanup(goose.ResetGlobalMigrations)
|
||||||
|
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
register = []*goose.Migration{
|
||||||
provider.WithGoMigration(
|
goose.NewGoMigration(
|
||||||
2,
|
2,
|
||||||
&provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
&goose.GoFunc{RunTx: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||||
&provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")},
|
&goose.GoFunc{RunTx: newTxFn("DELETE FROM users")},
|
||||||
),
|
),
|
||||||
|
}
|
||||||
|
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||||
|
goose.WithGoMigrations(register...),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
check.Number(t, len(p.ListSources()), 2)
|
||||||
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
|
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||||
assertSource(t, sources[1], provider.TypeGo, "", 2)
|
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||||
// Apply migration 1
|
// Apply migration 1
|
||||||
res, err := p.UpByOne(ctx)
|
res, err := p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||||
check.Number(t, countUser(db), 0)
|
check.Number(t, countUser(db), 0)
|
||||||
check.Bool(t, tableExists(t, db, "users"), true)
|
check.Bool(t, tableExists(t, db, "users"), true)
|
||||||
// Apply migration 2
|
// Apply migration 2
|
||||||
res, err = p.UpByOne(ctx)
|
res, err = p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||||
check.Number(t, countUser(db), 3)
|
check.Number(t, countUser(db), 3)
|
||||||
// Rollback migration 2
|
// Rollback migration 2
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||||
check.Number(t, countUser(db), 0)
|
check.Number(t, countUser(db), 0)
|
||||||
// Rollback migration 1
|
// Rollback migration 1
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||||
// Check table does not exist
|
// Check table does not exist
|
||||||
check.Bool(t, tableExists(t, db, "users"), false)
|
check.Bool(t, tableExists(t, db, "users"), false)
|
||||||
})
|
})
|
||||||
t.Run("with_db", func(t *testing.T) {
|
t.Run("with_db", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
register := []*provider.MigrationCopy{
|
register := []*goose.Migration{
|
||||||
{
|
goose.NewGoMigration(
|
||||||
Version: 1, Source: "00001_users_table.go", Registered: true,
|
1,
|
||||||
UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
&goose.GoFunc{
|
||||||
DownFnNoTxContext: newDBFn("DROP TABLE users"),
|
RunDB: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"),
|
||||||
},
|
},
|
||||||
|
&goose.GoFunc{
|
||||||
|
RunDB: newDBFn("DROP TABLE users"),
|
||||||
|
},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
err := provider.SetGlobalGoMigrations(register)
|
err := goose.SetGlobalMigrations(register...)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
t.Cleanup(provider.ResetGlobalGoMigrations)
|
t.Cleanup(goose.ResetGlobalMigrations)
|
||||||
|
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, nil,
|
register = []*goose.Migration{
|
||||||
provider.WithGoMigration(
|
goose.NewGoMigration(
|
||||||
2,
|
2,
|
||||||
&provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
&goose.GoFunc{RunDB: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")},
|
||||||
&provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")},
|
&goose.GoFunc{RunDB: newDBFn("DELETE FROM users")},
|
||||||
),
|
),
|
||||||
|
}
|
||||||
|
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
|
||||||
|
goose.WithGoMigrations(register...),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
check.Number(t, len(p.ListSources()), 2)
|
||||||
assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
|
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||||
assertSource(t, sources[1], provider.TypeGo, "", 2)
|
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||||
// Apply migration 1
|
// Apply migration 1
|
||||||
res, err := p.UpByOne(ctx)
|
res, err := p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||||
check.Number(t, countUser(db), 0)
|
check.Number(t, countUser(db), 0)
|
||||||
check.Bool(t, tableExists(t, db, "users"), true)
|
check.Bool(t, tableExists(t, db, "users"), true)
|
||||||
// Apply migration 2
|
// Apply migration 2
|
||||||
res, err = p.UpByOne(ctx)
|
res, err = p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||||
check.Number(t, countUser(db), 3)
|
check.Number(t, countUser(db), 3)
|
||||||
// Rollback migration 2
|
// Rollback migration 2
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||||
check.Number(t, countUser(db), 0)
|
check.Number(t, countUser(db), 0)
|
||||||
// Rollback migration 1
|
// Rollback migration 1
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||||
// Check table does not exist
|
// Check table does not exist
|
||||||
check.Bool(t, tableExists(t, db, "users"), false)
|
check.Bool(t, tableExists(t, db, "users"), false)
|
||||||
})
|
})
|
||||||
|
@ -818,16 +829,23 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
newProvider := func() *provider.Provider {
|
newProvider := func() *goose.Provider {
|
||||||
sessionLocker, err := lock.NewPostgresSessionLocker()
|
|
||||||
check.NoError(t, err)
|
sessionLocker, err := lock.NewPostgresSessionLocker(
|
||||||
p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"),
|
lock.WithLockTimeout(5, 60), // Timeout 5min. Try every 5s up to 60 times.
|
||||||
provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
|
|
||||||
provider.WithVerbose(testing.Verbose()),
|
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
|
p, err := goose.NewProvider(
|
||||||
|
database.DialectPostgres,
|
||||||
|
db,
|
||||||
|
os.DirFS("testdata/migrations"),
|
||||||
|
goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
|
||||||
|
)
|
||||||
|
check.NoError(t, err)
|
||||||
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
provider1 := newProvider()
|
provider1 := newProvider()
|
||||||
provider2 := newProvider()
|
provider2 := newProvider()
|
||||||
|
|
||||||
|
@ -891,7 +909,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||||
for {
|
for {
|
||||||
result, err := provider1.UpByOne(context.Background())
|
result, err := provider1.UpByOne(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -907,7 +925,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||||
for {
|
for {
|
||||||
result, err := provider2.UpByOne(context.Background())
|
result, err := provider2.UpByOne(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -993,7 +1011,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||||
for {
|
for {
|
||||||
result, err := provider1.Down(context.Background())
|
result, err := provider1.Down(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -1009,7 +1027,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||||
for {
|
for {
|
||||||
result, err := provider2.Down(context.Background())
|
result, err := provider2.Down(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, provider.ErrNoNextVersion) {
|
if errors.Is(err, goose.ErrNoNextVersion) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -1068,14 +1086,14 @@ func randomAlphaNumeric(length int) string {
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) {
|
func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provider, *sql.DB) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
opts = append(
|
opts = append(
|
||||||
opts,
|
opts,
|
||||||
provider.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
)
|
)
|
||||||
p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
|
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
|
||||||
check.NoError(t, err)
|
check.NoError(t, err)
|
||||||
return p, db
|
return p, db
|
||||||
}
|
}
|
||||||
|
@ -1118,14 +1136,14 @@ func getTableNames(db *sql.DB) ([]string, error) {
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) {
|
func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.State, state)
|
check.Equal(t, got.State, state)
|
||||||
check.Equal(t, got.Source, source)
|
check.Equal(t, got.Source, source)
|
||||||
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
|
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
|
func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Bool(t, got != nil, true)
|
check.Bool(t, got != nil, true)
|
||||||
check.Equal(t, got.Source, source)
|
check.Equal(t, got.Source, source)
|
||||||
|
@ -1135,21 +1153,15 @@ func assertResult(t *testing.T, got *provider.MigrationResult, source provider.S
|
||||||
check.Bool(t, got.Duration > 0, true)
|
check.Bool(t, got.Duration > 0, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) {
|
func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.Type, typ)
|
check.Equal(t, got.Type, typ)
|
||||||
check.Equal(t, got.Path, name)
|
check.Equal(t, got.Path, name)
|
||||||
check.Equal(t, got.Version, version)
|
check.Equal(t, got.Version, version)
|
||||||
switch got.Type {
|
|
||||||
case provider.TypeGo:
|
|
||||||
check.Equal(t, got.Type.String(), "go")
|
|
||||||
case provider.TypeSQL:
|
|
||||||
check.Equal(t, got.Type.String(), "sql")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source {
|
func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source {
|
||||||
return provider.Source{
|
return &goose.Source{
|
||||||
Type: t,
|
Type: t,
|
||||||
Path: fullpath,
|
Path: fullpath,
|
||||||
Version: version,
|
Version: version,
|
|
@ -0,0 +1,79 @@
|
||||||
|
package goose_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"io/fs"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
"github.com/pressly/goose/v3/database"
|
||||||
|
"github.com/pressly/goose/v3/internal/check"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProvider(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||||
|
check.NoError(t, err)
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
_, err := goose.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
|
||||||
|
check.HasError(t, err)
|
||||||
|
check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true)
|
||||||
|
})
|
||||||
|
|
||||||
|
mapFS := fstest.MapFS{
|
||||||
|
"migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)},
|
||||||
|
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
||||||
|
}
|
||||||
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
|
check.NoError(t, err)
|
||||||
|
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys)
|
||||||
|
check.NoError(t, err)
|
||||||
|
sources := p.ListSources()
|
||||||
|
check.Equal(t, len(sources), 2)
|
||||||
|
check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1))
|
||||||
|
check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
migration1 = `
|
||||||
|
-- +goose Up
|
||||||
|
CREATE TABLE foo (id INTEGER PRIMARY KEY);
|
||||||
|
-- +goose Down
|
||||||
|
DROP TABLE foo;
|
||||||
|
`
|
||||||
|
migration2 = `
|
||||||
|
-- +goose Up
|
||||||
|
ALTER TABLE foo ADD COLUMN name TEXT;
|
||||||
|
-- +goose Down
|
||||||
|
ALTER TABLE foo DROP COLUMN name;
|
||||||
|
`
|
||||||
|
migration3 = `
|
||||||
|
-- +goose Up
|
||||||
|
CREATE TABLE bar (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
description TEXT
|
||||||
|
);
|
||||||
|
-- +goose Down
|
||||||
|
DROP TABLE bar;
|
||||||
|
`
|
||||||
|
migration4 = `
|
||||||
|
-- +goose Up
|
||||||
|
-- Rename the 'foo' table to 'my_foo'
|
||||||
|
ALTER TABLE foo RENAME TO my_foo;
|
||||||
|
|
||||||
|
-- Add a new column 'timestamp' to 'my_foo'
|
||||||
|
ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
-- Remove the 'timestamp' column from 'my_foo'
|
||||||
|
ALTER TABLE my_foo DROP COLUMN timestamp;
|
||||||
|
|
||||||
|
-- Rename the 'my_foo' table back to 'foo'
|
||||||
|
ALTER TABLE my_foo RENAME TO foo;
|
||||||
|
`
|
||||||
|
)
|
|
@ -1,30 +1,15 @@
|
||||||
package provider
|
package goose
|
||||||
|
|
||||||
import (
|
import "time"
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MigrationType is the type of migration.
|
// MigrationType is the type of migration.
|
||||||
type MigrationType int
|
type MigrationType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TypeGo MigrationType = iota + 1
|
TypeGo MigrationType = "go"
|
||||||
TypeSQL
|
TypeSQL MigrationType = "sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t MigrationType) String() string {
|
|
||||||
switch t {
|
|
||||||
case TypeGo:
|
|
||||||
return "go"
|
|
||||||
case TypeSQL:
|
|
||||||
return "sql"
|
|
||||||
default:
|
|
||||||
// This should never happen.
|
|
||||||
return fmt.Sprintf("unknown (%d)", t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Source represents a single migration source.
|
// Source represents a single migration source.
|
||||||
//
|
//
|
||||||
// The Path field may be empty if the migration was registered manually. This is typically the case
|
// The Path field may be empty if the migration was registered manually. This is typically the case
|
||||||
|
@ -37,7 +22,7 @@ type Source struct {
|
||||||
|
|
||||||
// MigrationResult is the result of a single migration operation.
|
// MigrationResult is the result of a single migration operation.
|
||||||
type MigrationResult struct {
|
type MigrationResult struct {
|
||||||
Source Source
|
Source *Source
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
Direction string
|
Direction string
|
||||||
// Empty indicates no action was taken during the migration, but it was still versioned. For
|
// Empty indicates no action was taken during the migration, but it was still versioned. For
|
||||||
|
@ -64,7 +49,7 @@ const (
|
||||||
|
|
||||||
// MigrationStatus represents the status of a single migration.
|
// MigrationStatus represents the status of a single migration.
|
||||||
type MigrationStatus struct {
|
type MigrationStatus struct {
|
||||||
Source Source
|
Source *Source
|
||||||
State State
|
State State
|
||||||
AppliedAt time.Time
|
AppliedAt time.Time
|
||||||
}
|
}
|
|
@ -66,7 +66,7 @@ func register(filename string, useTx bool, up, down *GoFunc) error {
|
||||||
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
|
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
|
||||||
// we know based on the register function what the user is requesting.
|
// we know based on the register function what the user is requesting.
|
||||||
m.UseTx = useTx
|
m.UseTx = useTx
|
||||||
registeredGoMigrations[v] = &m
|
registeredGoMigrations[v] = m
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
17
types.go
17
types.go
|
@ -1,17 +0,0 @@
|
||||||
package goose
|
|
||||||
|
|
||||||
// MigrationType is the type of migration.
|
|
||||||
type MigrationType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeGo MigrationType = "go"
|
|
||||||
TypeSQL MigrationType = "sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (t MigrationType) String() string {
|
|
||||||
// This should never happen.
|
|
||||||
if t == "" {
|
|
||||||
return "unknown migration type"
|
|
||||||
}
|
|
||||||
return string(t)
|
|
||||||
}
|
|
Loading…
Reference in New Issue