mirror of
https://github.com/pressly/goose.git
synced 2025-05-31 11:42:04 +00:00
feat: Add NewGoMigration constructor (#631)
This commit is contained in:
parent
432a6ac0f8
commit
6d0de39c50
107
globals.go
107
globals.go
@ -3,44 +3,42 @@ package goose
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var (
|
||||
registeredGoMigrations = make(map[int64]*Migration)
|
||||
)
|
||||
|
||||
// ResetGlobalMigrations resets the global go migrations registry.
|
||||
// ResetGlobalMigrations resets the global Go migrations registry.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func ResetGlobalMigrations() {
|
||||
registeredGoMigrations = make(map[int64]*Migration)
|
||||
}
|
||||
|
||||
// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the
|
||||
// same version has already been registered.
|
||||
//
|
||||
// Source may be empty, but if it is set, it must be a path with a numeric component that matches
|
||||
// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx.
|
||||
// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the
|
||||
// same version has already been registered. Go migrations must be constructed using the
|
||||
// [NewGoMigration] function.
|
||||
//
|
||||
// Not safe for concurrent use.
|
||||
func SetGlobalMigrations(migrations ...Migration) error {
|
||||
for _, m := range migrations {
|
||||
// make a copy of the migration so we can modify it without affecting the original.
|
||||
if err := validGoMigration(&m); err != nil {
|
||||
return fmt.Errorf("invalid go migration: %w", err)
|
||||
}
|
||||
for _, migration := range migrations {
|
||||
m := &migration
|
||||
if _, ok := registeredGoMigrations[m.Version]; ok {
|
||||
return fmt.Errorf("go migration with version %d already registered", m.Version)
|
||||
}
|
||||
m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user.
|
||||
registeredGoMigrations[m.Version] = &m
|
||||
if err := checkMigration(m); err != nil {
|
||||
return fmt.Errorf("invalid go migration: %w", err)
|
||||
}
|
||||
registeredGoMigrations[m.Version] = m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validGoMigration(m *Migration) error {
|
||||
if m == nil {
|
||||
return errors.New("must not be nil")
|
||||
func checkMigration(m *Migration) error {
|
||||
if !m.construct {
|
||||
return errors.New("must use NewGoMigration to construct migrations")
|
||||
}
|
||||
if !m.Registered {
|
||||
return errors.New("must be registered")
|
||||
@ -52,36 +50,81 @@ func validGoMigration(m *Migration) error {
|
||||
return errors.New("version must be greater than zero")
|
||||
}
|
||||
if m.Source != "" {
|
||||
if filepath.Ext(m.Source) != ".go" {
|
||||
return fmt.Errorf("source must have .go extension: %q", m.Source)
|
||||
}
|
||||
// If the source is set, expect it to be a path with a numeric component that matches the
|
||||
// version. This field is not intended to be used for descriptive purposes.
|
||||
version, err := NumericComponent(m.Source)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("invalid source: %w", err)
|
||||
}
|
||||
if version != m.Version {
|
||||
return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source)
|
||||
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
|
||||
}
|
||||
}
|
||||
// It's valid for all of these funcs to be nil. Which means version the go migration but do not
|
||||
// run anything.
|
||||
if err := setGoFunc(m.goUp); err != nil {
|
||||
return fmt.Errorf("up function: %w", err)
|
||||
}
|
||||
if err := setGoFunc(m.goDown); err != nil {
|
||||
return fmt.Errorf("down function: %w", err)
|
||||
}
|
||||
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
}
|
||||
if m.UpFn != nil && m.UpFnNoTx != nil {
|
||||
return errors.New("must specify exactly one of UpFn or UpFnNoTx")
|
||||
}
|
||||
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
|
||||
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
}
|
||||
// Do not allow legacy functions to be set.
|
||||
if m.UpFn != nil {
|
||||
return errors.New("must not specify UpFn")
|
||||
}
|
||||
if m.DownFn != nil {
|
||||
return errors.New("must not specify DownFn")
|
||||
}
|
||||
if m.UpFnNoTx != nil {
|
||||
return errors.New("must not specify UpFnNoTx")
|
||||
}
|
||||
if m.DownFnNoTx != nil {
|
||||
return errors.New("must not specify DownFnNoTx")
|
||||
if m.DownFn != nil && m.DownFnNoTx != nil {
|
||||
return errors.New("must specify exactly one of DownFn or DownFnNoTx")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setGoFunc(f *GoFunc) error {
|
||||
if f == nil {
|
||||
f = &GoFunc{Mode: TransactionEnabled}
|
||||
return nil
|
||||
}
|
||||
if f.RunTx != nil && f.RunDB != nil {
|
||||
return errors.New("must specify exactly one of RunTx or RunDB")
|
||||
}
|
||||
if f.RunTx == nil && f.RunDB == nil {
|
||||
switch f.Mode {
|
||||
case 0:
|
||||
// Default to TransactionEnabled ONLY if mode is not set explicitly.
|
||||
f.Mode = TransactionEnabled
|
||||
case TransactionEnabled, TransactionDisabled:
|
||||
// No functions but mode is set. This is not an error. It means the user wants to record
|
||||
// a version with the given mode but not run any functions.
|
||||
default:
|
||||
return fmt.Errorf("invalid mode: %d", f.Mode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if f.RunDB != nil {
|
||||
switch f.Mode {
|
||||
case 0, TransactionDisabled:
|
||||
f.Mode = TransactionDisabled
|
||||
default:
|
||||
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
|
||||
}
|
||||
}
|
||||
if f.RunTx != nil {
|
||||
switch f.Mode {
|
||||
case 0, TransactionEnabled:
|
||||
f.Mode = TransactionEnabled
|
||||
default:
|
||||
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
|
||||
}
|
||||
}
|
||||
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
|
||||
// the functions or return an error. This should never happen.
|
||||
if f.Mode == 0 {
|
||||
return errors.New("failed to infer transaction mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
341
globals_test.go
341
globals_test.go
@ -1,113 +1,266 @@
|
||||
package goose_test
|
||||
package goose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
)
|
||||
|
||||
func TestGlobalRegister(t *testing.T) {
|
||||
// Avoid polluting other tests and do not run in parallel.
|
||||
t.Cleanup(func() {
|
||||
goose.ResetGlobalMigrations()
|
||||
func TestNewGoMigration(t *testing.T) {
|
||||
t.Run("valid_both_nil", func(t *testing.T) {
|
||||
m := NewGoMigration(1, nil, nil)
|
||||
// roundtrip
|
||||
check.Equal(t, m.Version, int64(1))
|
||||
check.Equal(t, m.Type, TypeGo)
|
||||
check.Equal(t, m.Registered, true)
|
||||
check.Equal(t, m.Next, int64(-1))
|
||||
check.Equal(t, m.Previous, int64(-1))
|
||||
check.Equal(t, m.Source, "")
|
||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
||||
check.Bool(t, m.UpFnContext == nil, true)
|
||||
check.Bool(t, m.DownFnContext == nil, true)
|
||||
check.Bool(t, m.UpFn == nil, true)
|
||||
check.Bool(t, m.DownFn == nil, true)
|
||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
||||
check.Bool(t, m.goUp != nil, true)
|
||||
check.Bool(t, m.goDown != nil, true)
|
||||
check.Equal(t, m.goUp.Mode, TransactionEnabled)
|
||||
check.Equal(t, m.goDown.Mode, TransactionEnabled)
|
||||
})
|
||||
fnNoTx := func(context.Context, *sql.DB) error { return nil }
|
||||
fn := func(context.Context, *sql.Tx) error { return nil }
|
||||
t.Run("all_set", func(t *testing.T) {
|
||||
// This will eventually be an error when registering migrations.
|
||||
m := NewGoMigration(
|
||||
1,
|
||||
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
|
||||
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
|
||||
)
|
||||
// check only functions
|
||||
check.Bool(t, m.UpFn != nil, true)
|
||||
check.Bool(t, m.UpFnContext != nil, true)
|
||||
check.Bool(t, m.UpFnNoTx != nil, true)
|
||||
check.Bool(t, m.UpFnNoTxContext != nil, true)
|
||||
check.Bool(t, m.DownFn != nil, true)
|
||||
check.Bool(t, m.DownFnContext != nil, true)
|
||||
check.Bool(t, m.DownFnNoTx != nil, true)
|
||||
check.Bool(t, m.DownFnNoTxContext != nil, true)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionMode(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
|
||||
runDB := func(context.Context, *sql.DB) error { return nil }
|
||||
runTx := func(context.Context, *sql.Tx) error { return nil }
|
||||
|
||||
err := SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB")
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB")
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set")
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set")
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set")
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set")
|
||||
|
||||
t.Run("default_mode", func(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
|
||||
m := NewGoMigration(1, nil, nil)
|
||||
err = SetGlobalMigrations(m)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(registeredGoMigrations), 1)
|
||||
registered := registeredGoMigrations[1]
|
||||
check.Bool(t, registered.goUp != nil, true)
|
||||
check.Bool(t, registered.goDown != nil, true)
|
||||
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
||||
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
||||
|
||||
migration2 := NewGoMigration(2, nil, nil)
|
||||
// reset so we can check the default is set
|
||||
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
||||
err = SetGlobalMigrations(migration2)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(registeredGoMigrations), 2)
|
||||
registered = registeredGoMigrations[2]
|
||||
check.Bool(t, registered.goUp != nil, true)
|
||||
check.Bool(t, registered.goDown != nil, true)
|
||||
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
||||
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
||||
})
|
||||
t.Run("unknown_mode", func(t *testing.T) {
|
||||
m := NewGoMigration(1, nil, nil)
|
||||
m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default
|
||||
err := SetGlobalMigrations(m)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid mode: 3")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLegacyFunctions(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
|
||||
runDB := func(context.Context, *sql.DB) error { return nil }
|
||||
runTx := func(context.Context, *sql.Tx) error { return nil }
|
||||
|
||||
assertMigration := func(t *testing.T, m *Migration, version int64) {
|
||||
t.Helper()
|
||||
check.Equal(t, m.Version, version)
|
||||
check.Equal(t, m.Type, TypeGo)
|
||||
check.Equal(t, m.Registered, true)
|
||||
check.Equal(t, m.Next, int64(-1))
|
||||
check.Equal(t, m.Previous, int64(-1))
|
||||
check.Equal(t, m.Source, "")
|
||||
}
|
||||
|
||||
t.Run("all_tx", func(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
err := SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(registeredGoMigrations), 1)
|
||||
m := registeredGoMigrations[1]
|
||||
assertMigration(t, m, 1)
|
||||
// Legacy functions.
|
||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
||||
// Context-aware functions.
|
||||
check.Bool(t, m.goUp == nil, false)
|
||||
check.Bool(t, m.UpFnContext == nil, false)
|
||||
check.Bool(t, m.goDown == nil, false)
|
||||
check.Bool(t, m.DownFnContext == nil, false)
|
||||
// Always nil
|
||||
check.Bool(t, m.UpFn == nil, false)
|
||||
check.Bool(t, m.DownFn == nil, false)
|
||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
||||
})
|
||||
t.Run("all_db", func(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
err := SetGlobalMigrations(
|
||||
NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(registeredGoMigrations), 1)
|
||||
m := registeredGoMigrations[2]
|
||||
assertMigration(t, m, 2)
|
||||
// Legacy functions.
|
||||
check.Bool(t, m.UpFnNoTxContext == nil, false)
|
||||
check.Bool(t, m.goUp == nil, false)
|
||||
check.Bool(t, m.DownFnNoTxContext == nil, false)
|
||||
check.Bool(t, m.goDown == nil, false)
|
||||
// Context-aware functions.
|
||||
check.Bool(t, m.UpFnContext == nil, true)
|
||||
check.Bool(t, m.DownFnContext == nil, true)
|
||||
// Always nil
|
||||
check.Bool(t, m.UpFn == nil, true)
|
||||
check.Bool(t, m.DownFn == nil, true)
|
||||
check.Bool(t, m.UpFnNoTx == nil, false)
|
||||
check.Bool(t, m.DownFnNoTx == nil, false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGlobalRegister(t *testing.T) {
|
||||
t.Cleanup(ResetGlobalMigrations)
|
||||
|
||||
// runDB := func(context.Context, *sql.DB) error { return nil }
|
||||
runTx := func(context.Context, *sql.Tx) error { return nil }
|
||||
|
||||
// Success.
|
||||
err := goose.SetGlobalMigrations(
|
||||
[]goose.Migration{}...,
|
||||
err := SetGlobalMigrations([]Migration{}...)
|
||||
check.NoError(t, err)
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn},
|
||||
)
|
||||
check.NoError(t, err)
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo},
|
||||
// Try to register the same migration again.
|
||||
err = SetGlobalMigrations(
|
||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "go migration with version 1 already registered")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{
|
||||
Registered: true,
|
||||
Version: 2,
|
||||
Source: "00002_foo.sql",
|
||||
Type: goose.TypeGo,
|
||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||
},
|
||||
)
|
||||
check.NoError(t, err)
|
||||
// Reset.
|
||||
{
|
||||
goose.ResetGlobalMigrations()
|
||||
}
|
||||
// Failure.
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{},
|
||||
)
|
||||
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must be registered")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`)
|
||||
// Legacy functions.
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx")
|
||||
// Context-aware functions.
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
// Source and version mismatch.
|
||||
err = goose.SetGlobalMigrations(
|
||||
goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo},
|
||||
)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`)
|
||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||
}
|
||||
|
||||
func TestCheckMigration(t *testing.T) {
|
||||
// Failures.
|
||||
err := checkMigration(&Migration{})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||
err = checkMigration(&Migration{construct: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must be registered")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `type must be "go"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "version must be greater than zero")
|
||||
// Success.
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
|
||||
check.NoError(t, err)
|
||||
// Failures.
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `no filename separator '_' found`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
UpFn: func(*sql.Tx) error { return nil },
|
||||
UpFnNoTx: func(*sql.DB) error { return nil },
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
||||
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||
DownFn: func(*sql.Tx) error { return nil },
|
||||
DownFnNoTx: func(*sql.DB) error { return nil },
|
||||
})
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
||||
}
|
||||
|
@ -8,13 +8,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func NotNil(t *testing.T, v any) {
|
||||
t.Helper()
|
||||
if v == nil {
|
||||
t.Fatal("unexpected nil value")
|
||||
}
|
||||
}
|
||||
|
||||
func NoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
|
@ -113,7 +113,7 @@ func TestProviderRun(t *testing.T) {
|
||||
break
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, res)
|
||||
check.Bool(t, res != nil, true)
|
||||
check.Number(t, res.Source.Version, int64(counter))
|
||||
}
|
||||
currentVersion, err := p.GetDBVersion(ctx)
|
||||
@ -132,7 +132,7 @@ func TestProviderRun(t *testing.T) {
|
||||
break
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, res)
|
||||
check.Bool(t, res != nil, true)
|
||||
check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
|
||||
}
|
||||
// Once everything is tested the version should match the highest testdata version
|
||||
@ -632,12 +632,12 @@ func TestAllowMissing(t *testing.T) {
|
||||
// 4
|
||||
upResult, err := p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, upResult)
|
||||
check.Bool(t, upResult != nil, true)
|
||||
check.Number(t, upResult.Source.Version, 4)
|
||||
// 6
|
||||
upResult, err = p.UpByOne(ctx)
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, upResult)
|
||||
check.Bool(t, upResult != nil, true)
|
||||
check.Number(t, upResult.Source.Version, 6)
|
||||
|
||||
count, err := getGooseVersionCount(db, provider.DefaultTablename)
|
||||
@ -660,7 +660,7 @@ func TestAllowMissing(t *testing.T) {
|
||||
check.Number(t, currentVersion, wantDBVersion)
|
||||
downRes, err := p.Down(ctx)
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, downRes)
|
||||
check.Bool(t, downRes != nil, true)
|
||||
check.Number(t, downRes.Source.Version, wantResultVersion)
|
||||
}
|
||||
|
||||
@ -897,7 +897,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
check.Bool(t, result != nil, true)
|
||||
mu.Lock()
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
@ -913,7 +913,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
check.Bool(t, result != nil, true)
|
||||
mu.Lock()
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
@ -999,7 +999,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
check.Bool(t, result != nil, true)
|
||||
mu.Lock()
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
@ -1015,7 +1015,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
check.NoError(t, err)
|
||||
check.NotNil(t, result)
|
||||
check.Bool(t, result != nil, true)
|
||||
mu.Lock()
|
||||
applied = append(applied, result.Source.Version)
|
||||
mu.Unlock()
|
||||
@ -1127,7 +1127,7 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St
|
||||
|
||||
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
|
||||
t.Helper()
|
||||
check.NotNil(t, got)
|
||||
check.Bool(t, got != nil, true)
|
||||
check.Equal(t, got.Source, source)
|
||||
check.Equal(t, got.Direction, direction)
|
||||
check.Equal(t, got.Empty, isEmpty)
|
||||
|
133
migrate.go
133
migrate.go
@ -8,7 +8,6 @@ import (
|
||||
"io/fs"
|
||||
"math"
|
||||
"path"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@ -125,115 +124,6 @@ func (ms Migrations) String() string {
|
||||
return str
|
||||
}
|
||||
|
||||
// GoMigration is a Go migration func that is run within a transaction.
|
||||
type GoMigration func(tx *sql.Tx) error
|
||||
|
||||
// GoMigrationContext is a Go migration func that is run within a transaction and receives a context.
|
||||
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error
|
||||
|
||||
// GoMigrationNoTx is a Go migration func that is run outside a transaction.
|
||||
type GoMigrationNoTx func(db *sql.DB) error
|
||||
|
||||
// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context.
|
||||
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error
|
||||
|
||||
// AddMigration adds Go migrations.
|
||||
//
|
||||
// Deprecated: Use AddMigrationContext.
|
||||
func AddMigration(up, down GoMigration) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddMigrationContext adds Go migrations.
|
||||
func AddMigrationContext(up, down GoMigrationContext) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationContext(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigration adds named Go migrations.
|
||||
//
|
||||
// Deprecated: Use AddNamedMigrationContext.
|
||||
func AddNamedMigration(filename string, up, down GoMigration) {
|
||||
AddNamedMigrationContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddNamedMigrationContext adds named Go migrations.
|
||||
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
|
||||
if err := register(filename, true, up, down, nil, nil); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// AddMigrationNoTx adds Go migrations that will be run outside transaction.
|
||||
//
|
||||
// Deprecated: Use AddNamedMigrationNoTxContext.
|
||||
func AddMigrationNoTx(up, down GoMigrationNoTx) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
|
||||
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationNoTxContext(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
|
||||
//
|
||||
// Deprecated: Use AddNamedMigrationNoTxContext.
|
||||
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
|
||||
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
|
||||
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
|
||||
if err := register(filename, false, nil, nil, up, down); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func register(
|
||||
filename string,
|
||||
useTx bool,
|
||||
up, down GoMigrationContext,
|
||||
upNoTx, downNoTx GoMigrationNoTxContext,
|
||||
) error {
|
||||
// Sanity check caller did not mix tx and non-tx based functions.
|
||||
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
|
||||
return fmt.Errorf("cannot mix tx and non-tx based go migrations functions")
|
||||
}
|
||||
v, _ := NumericComponent(filename)
|
||||
if existing, ok := registeredGoMigrations[v]; ok {
|
||||
return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
|
||||
filename,
|
||||
v,
|
||||
existing.Source,
|
||||
)
|
||||
}
|
||||
// Add to global as a registered migration.
|
||||
registeredGoMigrations[v] = &Migration{
|
||||
Version: v,
|
||||
Next: -1,
|
||||
Previous: -1,
|
||||
Registered: true,
|
||||
Source: filename,
|
||||
UseTx: useTx,
|
||||
UpFnContext: up,
|
||||
DownFnContext: down,
|
||||
UpFnNoTxContext: upNoTx,
|
||||
DownFnNoTxContext: downNoTx,
|
||||
// These are deprecated and will be removed in the future.
|
||||
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
|
||||
// Goose does not use these internally anymore and instead uses the context versions.
|
||||
UpFn: withoutContext(up),
|
||||
DownFn: withoutContext(down),
|
||||
UpFnNoTx: withoutContext(upNoTx),
|
||||
DownFnNoTx: withoutContext(downNoTx),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func collectMigrationsFS(
|
||||
fsys fs.FS,
|
||||
dirpath string,
|
||||
@ -388,29 +278,6 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// withContext changes the signature of a function that receives one argument to receive a context and the argument.
|
||||
func withContext[T any](fn func(T) error) func(context.Context, T) error {
|
||||
if fn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, t T) error {
|
||||
return fn(t)
|
||||
}
|
||||
}
|
||||
|
||||
// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument.
|
||||
// When called the passed context is always context.Background().
|
||||
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
|
||||
if fn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(t T) error {
|
||||
return fn(context.Background(), t)
|
||||
}
|
||||
}
|
||||
|
||||
// collectGoMigrations collects Go migrations from the filesystem and merges them with registered
|
||||
// migrations.
|
||||
//
|
||||
|
151
migration.go
151
migration.go
@ -13,36 +13,141 @@ import (
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
)
|
||||
|
||||
// NewGoMigration creates a new Go migration.
|
||||
//
|
||||
// Both up and down functions may be nil, in which case the migration will be recorded in the
|
||||
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
|
||||
// a version without running any functions. See [GoFunc] for more details.
|
||||
func NewGoMigration(version int64, up, down *GoFunc) Migration {
|
||||
m := Migration{
|
||||
Type: TypeGo,
|
||||
Registered: true,
|
||||
Version: version,
|
||||
Next: -1, Previous: -1,
|
||||
goUp: up,
|
||||
goDown: down,
|
||||
construct: true,
|
||||
}
|
||||
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
|
||||
// we will remove these fields in favor of [GoFunc].
|
||||
//
|
||||
// Note, this function does not do any validation. Validation is lazily done when the migration
|
||||
// is registered.
|
||||
if up != nil {
|
||||
if up.RunDB != nil {
|
||||
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
|
||||
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
|
||||
}
|
||||
if up.RunTx != nil {
|
||||
m.UseTx = true
|
||||
m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error
|
||||
m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error
|
||||
}
|
||||
}
|
||||
if down != nil {
|
||||
if down.RunDB != nil {
|
||||
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
|
||||
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
|
||||
}
|
||||
if down.RunTx != nil {
|
||||
m.UseTx = true
|
||||
m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error
|
||||
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
|
||||
}
|
||||
}
|
||||
if m.goUp == nil {
|
||||
m.goUp = &GoFunc{Mode: TransactionEnabled}
|
||||
}
|
||||
if m.goDown == nil {
|
||||
m.goDown = &GoFunc{Mode: TransactionEnabled}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Migration struct represents either a SQL or Go migration.
|
||||
//
|
||||
// Avoid constructing migrations manually, use [NewGoMigration] function.
|
||||
type Migration struct {
|
||||
Type MigrationType
|
||||
Version int64
|
||||
// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
|
||||
// have been registered globally and don't have a source file.
|
||||
Source string
|
||||
|
||||
UpFnContext, DownFnContext GoMigrationContext
|
||||
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
||||
// These fields are used internally by goose and users are not expected to set them. Instead,
|
||||
// use [NewGoMigration] to create a new go migration.
|
||||
construct bool
|
||||
goUp, goDown *GoFunc
|
||||
|
||||
// These fields will be removed in a future major version. They are here for backwards
|
||||
// compatibility and are an implementation detail.
|
||||
Registered bool
|
||||
UseTx bool
|
||||
Next int64 // next version, or -1 if none
|
||||
Previous int64 // previous version, -1 if none
|
||||
|
||||
// We still save the non-context versions in the struct in case someone is using them. Goose
|
||||
// does not use these internally anymore in favor of the context-aware versions. These fields
|
||||
// will be removed in a future major version.
|
||||
|
||||
UpFn GoMigration // Deprecated: use UpFnContext instead.
|
||||
DownFn GoMigration // Deprecated: use DownFnContext instead.
|
||||
UpFnNoTx GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
|
||||
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
|
||||
|
||||
noVersioning bool
|
||||
}
|
||||
|
||||
// GoFunc represents a Go migration function.
|
||||
type GoFunc struct {
|
||||
// Exactly one of these must be set, or both must be nil.
|
||||
RunTx func(ctx context.Context, tx *sql.Tx) error
|
||||
// -- OR --
|
||||
RunDB func(ctx context.Context, db *sql.DB) error
|
||||
|
||||
// Mode is the transaction mode for the migration. When one of the run functions is set, the
|
||||
// mode will be inferred from the function and the field is ignored. Users do not need to set
|
||||
// this field when supplying a run function.
|
||||
//
|
||||
// If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil
|
||||
// functions is to record a version in the version table without invoking a Go migration
|
||||
// function.
|
||||
//
|
||||
// The only time this field is required is if BOTH run functions are nil AND you want to
|
||||
// override the default transaction mode.
|
||||
Mode TransactionMode
|
||||
}
|
||||
|
||||
// TransactionMode represents the possible transaction modes for a migration.
|
||||
type TransactionMode int
|
||||
|
||||
const (
|
||||
TransactionEnabled TransactionMode = iota + 1
|
||||
TransactionDisabled
|
||||
)
|
||||
|
||||
func (m TransactionMode) String() string {
|
||||
switch m {
|
||||
case TransactionEnabled:
|
||||
return "transaction_enabled"
|
||||
case TransactionDisabled:
|
||||
return "transaction_disabled"
|
||||
default:
|
||||
return fmt.Sprintf("unknown transaction mode (%d)", m)
|
||||
}
|
||||
}
|
||||
|
||||
// MigrationRecord struct.
|
||||
//
|
||||
// Deprecated: unused and will be removed in a future major version.
|
||||
type MigrationRecord struct {
|
||||
VersionID int64
|
||||
TStamp time.Time
|
||||
IsApplied bool // was this a result of up() or down()
|
||||
}
|
||||
|
||||
// Migration struct represents either a SQL or Go migration.
|
||||
type Migration struct {
|
||||
Type MigrationType
|
||||
Version int64
|
||||
Source string // path to .sql script or .go file
|
||||
Registered bool
|
||||
UpFnContext, DownFnContext GoMigrationContext
|
||||
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
||||
|
||||
// These fields will be removed in a future major version. They are here for backwards
|
||||
// compatibility and are an implementation detail.
|
||||
UseTx bool
|
||||
Next int64 // next version, or -1 if none
|
||||
Previous int64 // previous version, -1 if none
|
||||
|
||||
// We still save the non-context versions in the struct in case someone is using them. Goose
|
||||
// does not use these internally anymore in favor of the context-aware versions.
|
||||
UpFn, DownFn GoMigration
|
||||
UpFnNoTx, DownFnNoTx GoMigrationNoTx
|
||||
|
||||
noVersioning bool
|
||||
}
|
||||
|
||||
func (m *Migration) String() string {
|
||||
return fmt.Sprint(m.Source)
|
||||
}
|
||||
|
133
register.go
Normal file
133
register.go
Normal file
@ -0,0 +1,133 @@
|
||||
package goose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GoMigrationContext is a Go migration func that is run within a transaction and receives a
|
||||
// context.
|
||||
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error
|
||||
|
||||
// AddMigrationContext adds Go migrations.
|
||||
func AddMigrationContext(up, down GoMigrationContext) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationContext(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigrationContext adds named Go migrations.
|
||||
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
|
||||
if err := register(
|
||||
filename,
|
||||
true,
|
||||
&GoFunc{RunTx: up},
|
||||
&GoFunc{RunTx: down},
|
||||
); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a
|
||||
// context.
|
||||
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error
|
||||
|
||||
// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
|
||||
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationNoTxContext(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
|
||||
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
|
||||
if err := register(
|
||||
filename,
|
||||
false,
|
||||
&GoFunc{RunDB: up},
|
||||
&GoFunc{RunDB: down},
|
||||
); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func register(filename string, useTx bool, up, down *GoFunc) error {
|
||||
v, _ := NumericComponent(filename)
|
||||
if existing, ok := registeredGoMigrations[v]; ok {
|
||||
return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
|
||||
filename,
|
||||
v,
|
||||
existing.Source,
|
||||
)
|
||||
}
|
||||
// Add to global as a registered migration.
|
||||
m := NewGoMigration(v, up, down)
|
||||
m.Source = filename
|
||||
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
|
||||
// we know based on the register function what the user is requesting.
|
||||
m.UseTx = useTx
|
||||
registeredGoMigrations[v] = &m
|
||||
return nil
|
||||
}
|
||||
|
||||
// withContext changes the signature of a function that receives one argument to receive a context
|
||||
// and the argument.
|
||||
func withContext[T any](fn func(T) error) func(context.Context, T) error {
|
||||
if fn == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ctx context.Context, t T) error {
|
||||
return fn(t)
|
||||
}
|
||||
}
|
||||
|
||||
// withoutContext changes the signature of a function that receives a context and one argument to
|
||||
// receive only the argument. When called the passed context is always context.Background().
|
||||
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
|
||||
if fn == nil {
|
||||
return nil
|
||||
}
|
||||
return func(t T) error {
|
||||
return fn(context.Background(), t)
|
||||
}
|
||||
}
|
||||
|
||||
// GoMigration is a Go migration func that is run within a transaction.
|
||||
//
|
||||
// Deprecated: Use GoMigrationContext.
|
||||
type GoMigration func(tx *sql.Tx) error
|
||||
|
||||
// GoMigrationNoTx is a Go migration func that is run outside a transaction.
|
||||
//
|
||||
// Deprecated: Use GoMigrationNoTxContext.
|
||||
type GoMigrationNoTx func(db *sql.DB) error
|
||||
|
||||
// AddMigration adds Go migrations.
|
||||
//
|
||||
// Deprecated: Use AddMigrationContext.
|
||||
func AddMigration(up, down GoMigration) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddNamedMigration adds named Go migrations.
|
||||
//
|
||||
// Deprecated: Use AddNamedMigrationContext.
|
||||
func AddNamedMigration(filename string, up, down GoMigration) {
|
||||
AddNamedMigrationContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddMigrationNoTx adds Go migrations that will be run outside transaction.
|
||||
//
|
||||
// Deprecated: Use AddMigrationNoTxContext.
|
||||
func AddMigrationNoTx(up, down GoMigrationNoTx) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
|
||||
}
|
||||
|
||||
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
|
||||
//
|
||||
// Deprecated: Use AddNamedMigrationNoTxContext.
|
||||
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
|
||||
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
|
||||
}
|
@ -1,27 +1,30 @@
|
||||
package gomigrations
|
||||
package gomigrations_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/testdb"
|
||||
|
||||
_ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestGoMigrationByOne(t *testing.T) {
|
||||
db, cleanup, err := testdb.NewPostgres()
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
t.Parallel()
|
||||
|
||||
check.NoError(t, goose.SetDialect("sqlite3"))
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
check.NoError(t, err)
|
||||
dir := "testdata"
|
||||
files, err := filepath.Glob(dir + "/*.go")
|
||||
check.NoError(t, err)
|
||||
|
||||
upByOne := func(t *testing.T) int64 {
|
||||
err = goose.UpByOne(db, dir)
|
||||
t.Logf("err: %v %s", err, dir)
|
||||
check.NoError(t, err)
|
||||
version, err := goose.GetDBVersion(db)
|
||||
check.NoError(t, err)
|
||||
@ -42,6 +45,21 @@ func TestGoMigrationByOne(t *testing.T) {
|
||||
check.NoError(t, err)
|
||||
check.Number(t, version, len(files))
|
||||
|
||||
tables, err := ListTables(db)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, tables, []string{
|
||||
"alpha",
|
||||
"bravo",
|
||||
"charlie",
|
||||
"delta",
|
||||
"echo",
|
||||
"foxtrot",
|
||||
"golf",
|
||||
"goose_db_version",
|
||||
"hotel",
|
||||
"sqlite_sequence",
|
||||
})
|
||||
|
||||
// Migrate all files down-by-one.
|
||||
for i := len(files) - 1; i >= 0; i-- {
|
||||
check.Number(t, downByOne(t), i)
|
||||
@ -49,4 +67,31 @@ func TestGoMigrationByOne(t *testing.T) {
|
||||
version, err = goose.GetDBVersion(db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, version, 0)
|
||||
|
||||
tables, err = ListTables(db)
|
||||
check.NoError(t, err)
|
||||
check.Equal(t, tables, []string{
|
||||
"goose_db_version",
|
||||
"sqlite_sequence",
|
||||
})
|
||||
}
|
||||
|
||||
func ListTables(db *sql.DB) ([]string, error) {
|
||||
rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, name)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
@ -1,9 +1,12 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/database"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -11,13 +14,19 @@ func init() {
|
||||
}
|
||||
|
||||
func up001(tx *sql.Tx) error {
|
||||
q := "CREATE TABLE foo (id INT, subid INT, name TEXT)"
|
||||
_, err := tx.Exec(q)
|
||||
return err
|
||||
return createTable(tx, "alpha")
|
||||
}
|
||||
|
||||
func down001(tx *sql.Tx) error {
|
||||
q := "DROP TABLE IF EXISTS foo"
|
||||
_, err := tx.Exec(q)
|
||||
return dropTable(tx, "alpha")
|
||||
}
|
||||
|
||||
func createTable(db database.DBTxConn, name string) error {
|
||||
_, err := db.ExecContext(context.Background(), fmt.Sprintf("CREATE TABLE %s (id INTEGER)", name))
|
||||
return err
|
||||
}
|
||||
|
||||
func dropTable(db database.DBTxConn, name string) error {
|
||||
_, err := db.ExecContext(context.Background(), fmt.Sprintf("DROP TABLE %s", name))
|
||||
return err
|
||||
}
|
||||
|
@ -11,7 +11,5 @@ func init() {
|
||||
}
|
||||
|
||||
func up002(tx *sql.Tx) error {
|
||||
q := "INSERT INTO foo VALUES (1, 1, 'Alice')"
|
||||
_, err := tx.Exec(q)
|
||||
return err
|
||||
return createTable(tx, "bravo")
|
||||
}
|
||||
|
@ -11,7 +11,5 @@ func init() {
|
||||
}
|
||||
|
||||
func down003(tx *sql.Tx) error {
|
||||
q := "TRUNCATE TABLE foo"
|
||||
_, err := tx.Exec(q)
|
||||
return err
|
||||
return dropTable(tx, "bravo")
|
||||
}
|
||||
|
@ -11,13 +11,9 @@ func init() {
|
||||
}
|
||||
|
||||
func up005(db *sql.DB) error {
|
||||
q := "CREATE TABLE users (id INT, email TEXT)"
|
||||
_, err := db.Exec(q)
|
||||
return err
|
||||
return createTable(db, "charlie")
|
||||
}
|
||||
|
||||
func down005(db *sql.DB) error {
|
||||
q := "DROP TABLE IF EXISTS users"
|
||||
_, err := db.Exec(q)
|
||||
return err
|
||||
return dropTable(db, "charlie")
|
||||
}
|
||||
|
@ -11,7 +11,5 @@ func init() {
|
||||
}
|
||||
|
||||
func up006(db *sql.DB) error {
|
||||
q := "INSERT INTO users VALUES (1, 'admin@example.com')"
|
||||
_, err := db.Exec(q)
|
||||
return err
|
||||
return createTable(db, "delta")
|
||||
}
|
||||
|
@ -11,7 +11,5 @@ func init() {
|
||||
}
|
||||
|
||||
func down007(db *sql.DB) error {
|
||||
q := "TRUNCATE TABLE users"
|
||||
_, err := db.Exec(q)
|
||||
return err
|
||||
return dropTable(db, "delta")
|
||||
}
|
||||
|
@ -5,5 +5,5 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigration(nil, nil)
|
||||
goose.AddMigrationNoTx(nil, nil)
|
||||
}
|
||||
|
20
tests/gomigrations/success/testdata/009_up_down_ctx.go
vendored
Normal file
20
tests/gomigrations/success/testdata/009_up_down_ctx.go
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(up009, down009)
|
||||
}
|
||||
|
||||
func up009(ctx context.Context, tx *sql.Tx) error {
|
||||
return createTable(tx, "echo")
|
||||
}
|
||||
|
||||
func down009(ctx context.Context, tx *sql.Tx) error {
|
||||
return dropTable(tx, "echo")
|
||||
}
|
16
tests/gomigrations/success/testdata/010_up_only_ctx.go
vendored
Normal file
16
tests/gomigrations/success/testdata/010_up_only_ctx.go
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(up010, nil)
|
||||
}
|
||||
|
||||
func up010(ctx context.Context, tx *sql.Tx) error {
|
||||
return createTable(tx, "foxtrot")
|
||||
}
|
16
tests/gomigrations/success/testdata/011_down_only_ctx.go
vendored
Normal file
16
tests/gomigrations/success/testdata/011_down_only_ctx.go
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(nil, down011)
|
||||
}
|
||||
|
||||
func down011(ctx context.Context, tx *sql.Tx) error {
|
||||
return dropTable(tx, "foxtrot")
|
||||
}
|
9
tests/gomigrations/success/testdata/012_empty_ctx.go
vendored
Normal file
9
tests/gomigrations/success/testdata/012_empty_ctx.go
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationContext(nil, nil)
|
||||
}
|
20
tests/gomigrations/success/testdata/013_up_down_no_tx_ctx.go
vendored
Normal file
20
tests/gomigrations/success/testdata/013_up_down_no_tx_ctx.go
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationNoTxContext(up013, down013)
|
||||
}
|
||||
|
||||
func up013(ctx context.Context, db *sql.DB) error {
|
||||
return createTable(db, "golf")
|
||||
}
|
||||
|
||||
func down013(ctx context.Context, db *sql.DB) error {
|
||||
return dropTable(db, "golf")
|
||||
}
|
16
tests/gomigrations/success/testdata/014_up_only_no_tx_ctx.go
vendored
Normal file
16
tests/gomigrations/success/testdata/014_up_only_no_tx_ctx.go
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationNoTxContext(up014, nil)
|
||||
}
|
||||
|
||||
func up014(ctx context.Context, db *sql.DB) error {
|
||||
return createTable(db, "hotel")
|
||||
}
|
16
tests/gomigrations/success/testdata/015_down_only_no_tx_ctx.go
vendored
Normal file
16
tests/gomigrations/success/testdata/015_down_only_no_tx_ctx.go
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationNoTxContext(nil, down015)
|
||||
}
|
||||
|
||||
func down015(ctx context.Context, db *sql.DB) error {
|
||||
return dropTable(db, "hotel")
|
||||
}
|
9
tests/gomigrations/success/testdata/016_empty_no_tx_ctx.go
vendored
Normal file
9
tests/gomigrations/success/testdata/016_empty_no_tx_ctx.go
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
package gomigrations
|
||||
|
||||
import (
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
goose.AddMigrationNoTxContext(nil, nil)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user