diff --git a/globals.go b/globals.go index e68bb0c..e2d55fa 100644 --- a/globals.go +++ b/globals.go @@ -3,44 +3,42 @@ package goose import ( "errors" "fmt" + "path/filepath" ) var ( registeredGoMigrations = make(map[int64]*Migration) ) -// ResetGlobalMigrations resets the global go migrations registry. +// ResetGlobalMigrations resets the global Go migrations registry. // // Not safe for concurrent use. func ResetGlobalMigrations() { registeredGoMigrations = make(map[int64]*Migration) } -// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the -// same version has already been registered. -// -// Source may be empty, but if it is set, it must be a path with a numeric component that matches -// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx. +// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the +// same version has already been registered. Go migrations must be constructed using the +// [NewGoMigration] function. // // Not safe for concurrent use. func SetGlobalMigrations(migrations ...Migration) error { - for _, m := range migrations { - // make a copy of the migration so we can modify it without affecting the original. - if err := validGoMigration(&m); err != nil { - return fmt.Errorf("invalid go migration: %w", err) - } + for _, migration := range migrations { + m := &migration if _, ok := registeredGoMigrations[m.Version]; ok { return fmt.Errorf("go migration with version %d already registered", m.Version) } - m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user. - registeredGoMigrations[m.Version] = &m + if err := checkMigration(m); err != nil { + return fmt.Errorf("invalid go migration: %w", err) + } + registeredGoMigrations[m.Version] = m } return nil } -func validGoMigration(m *Migration) error { - if m == nil { - return errors.New("must not be nil") +func checkMigration(m *Migration) error { + if !m.construct { + return errors.New("must use NewGoMigration to construct migrations") } if !m.Registered { return errors.New("must be registered") @@ -52,36 +50,81 @@ func validGoMigration(m *Migration) error { return errors.New("version must be greater than zero") } if m.Source != "" { + if filepath.Ext(m.Source) != ".go" { + return fmt.Errorf("source must have .go extension: %q", m.Source) + } // If the source is set, expect it to be a path with a numeric component that matches the // version. This field is not intended to be used for descriptive purposes. version, err := NumericComponent(m.Source) if err != nil { - return err + return fmt.Errorf("invalid source: %w", err) } if version != m.Version { - return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source) + return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source) } } - // It's valid for all of these funcs to be nil. Which means version the go migration but do not - // run anything. + if err := setGoFunc(m.goUp); err != nil { + return fmt.Errorf("up function: %w", err) + } + if err := setGoFunc(m.goDown); err != nil { + return fmt.Errorf("down function: %w", err) + } if m.UpFnContext != nil && m.UpFnNoTxContext != nil { return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") } + if m.UpFn != nil && m.UpFnNoTx != nil { + return errors.New("must specify exactly one of UpFn or UpFnNoTx") + } if m.DownFnContext != nil && m.DownFnNoTxContext != nil { return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") } - // Do not allow legacy functions to be set. - if m.UpFn != nil { - return errors.New("must not specify UpFn") - } - if m.DownFn != nil { - return errors.New("must not specify DownFn") - } - if m.UpFnNoTx != nil { - return errors.New("must not specify UpFnNoTx") - } - if m.DownFnNoTx != nil { - return errors.New("must not specify DownFnNoTx") + if m.DownFn != nil && m.DownFnNoTx != nil { + return errors.New("must specify exactly one of DownFn or DownFnNoTx") + } + return nil +} + +func setGoFunc(f *GoFunc) error { + if f == nil { + f = &GoFunc{Mode: TransactionEnabled} + return nil + } + if f.RunTx != nil && f.RunDB != nil { + return errors.New("must specify exactly one of RunTx or RunDB") + } + if f.RunTx == nil && f.RunDB == nil { + switch f.Mode { + case 0: + // Default to TransactionEnabled ONLY if mode is not set explicitly. + f.Mode = TransactionEnabled + case TransactionEnabled, TransactionDisabled: + // No functions but mode is set. This is not an error. It means the user wants to record + // a version with the given mode but not run any functions. + default: + return fmt.Errorf("invalid mode: %d", f.Mode) + } + return nil + } + if f.RunDB != nil { + switch f.Mode { + case 0, TransactionDisabled: + f.Mode = TransactionDisabled + default: + return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set") + } + } + if f.RunTx != nil { + switch f.Mode { + case 0, TransactionEnabled: + f.Mode = TransactionEnabled + default: + return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set") + } + } + // This is a defensive check. If the mode is still 0, it means we failed to infer the mode from + // the functions or return an error. This should never happen. + if f.Mode == 0 { + return errors.New("failed to infer transaction mode") } return nil } diff --git a/globals_test.go b/globals_test.go index 03febfa..f40f01f 100644 --- a/globals_test.go +++ b/globals_test.go @@ -1,113 +1,266 @@ -package goose_test +package goose import ( "context" "database/sql" "testing" - "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/check" ) -func TestGlobalRegister(t *testing.T) { - // Avoid polluting other tests and do not run in parallel. - t.Cleanup(func() { - goose.ResetGlobalMigrations() +func TestNewGoMigration(t *testing.T) { + t.Run("valid_both_nil", func(t *testing.T) { + m := NewGoMigration(1, nil, nil) + // roundtrip + check.Equal(t, m.Version, int64(1)) + check.Equal(t, m.Type, TypeGo) + check.Equal(t, m.Registered, true) + check.Equal(t, m.Next, int64(-1)) + check.Equal(t, m.Previous, int64(-1)) + check.Equal(t, m.Source, "") + check.Bool(t, m.UpFnNoTxContext == nil, true) + check.Bool(t, m.DownFnNoTxContext == nil, true) + check.Bool(t, m.UpFnContext == nil, true) + check.Bool(t, m.DownFnContext == nil, true) + check.Bool(t, m.UpFn == nil, true) + check.Bool(t, m.DownFn == nil, true) + check.Bool(t, m.UpFnNoTx == nil, true) + check.Bool(t, m.DownFnNoTx == nil, true) + check.Bool(t, m.goUp != nil, true) + check.Bool(t, m.goDown != nil, true) + check.Equal(t, m.goUp.Mode, TransactionEnabled) + check.Equal(t, m.goDown.Mode, TransactionEnabled) }) - fnNoTx := func(context.Context, *sql.DB) error { return nil } - fn := func(context.Context, *sql.Tx) error { return nil } + t.Run("all_set", func(t *testing.T) { + // This will eventually be an error when registering migrations. + m := NewGoMigration( + 1, + &GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }}, + &GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }}, + ) + // check only functions + check.Bool(t, m.UpFn != nil, true) + check.Bool(t, m.UpFnContext != nil, true) + check.Bool(t, m.UpFnNoTx != nil, true) + check.Bool(t, m.UpFnNoTxContext != nil, true) + check.Bool(t, m.DownFn != nil, true) + check.Bool(t, m.DownFnContext != nil, true) + check.Bool(t, m.DownFnNoTx != nil, true) + check.Bool(t, m.DownFnNoTxContext != nil, true) + }) +} + +func TestTransactionMode(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + + runDB := func(context.Context, *sql.DB) error { return nil } + runTx := func(context.Context, *sql.Tx) error { return nil } + + err := SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB") + err = SetGlobalMigrations( + NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB") + err = SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set") + err = SetGlobalMigrations( + NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set") + err = SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set") + err = SetGlobalMigrations( + NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set") + + t.Run("default_mode", func(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + + m := NewGoMigration(1, nil, nil) + err = SetGlobalMigrations(m) + check.NoError(t, err) + check.Number(t, len(registeredGoMigrations), 1) + registered := registeredGoMigrations[1] + check.Bool(t, registered.goUp != nil, true) + check.Bool(t, registered.goDown != nil, true) + check.Equal(t, registered.goUp.Mode, TransactionEnabled) + check.Equal(t, registered.goDown.Mode, TransactionEnabled) + + migration2 := NewGoMigration(2, nil, nil) + // reset so we can check the default is set + migration2.goUp.Mode, migration2.goDown.Mode = 0, 0 + err = SetGlobalMigrations(migration2) + check.NoError(t, err) + check.Number(t, len(registeredGoMigrations), 2) + registered = registeredGoMigrations[2] + check.Bool(t, registered.goUp != nil, true) + check.Bool(t, registered.goDown != nil, true) + check.Equal(t, registered.goUp.Mode, TransactionEnabled) + check.Equal(t, registered.goDown.Mode, TransactionEnabled) + }) + t.Run("unknown_mode", func(t *testing.T) { + m := NewGoMigration(1, nil, nil) + m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default + err := SetGlobalMigrations(m) + check.HasError(t, err) + check.Contains(t, err.Error(), "invalid mode: 3") + }) +} + +func TestLegacyFunctions(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + + runDB := func(context.Context, *sql.DB) error { return nil } + runTx := func(context.Context, *sql.Tx) error { return nil } + + assertMigration := func(t *testing.T, m *Migration, version int64) { + t.Helper() + check.Equal(t, m.Version, version) + check.Equal(t, m.Type, TypeGo) + check.Equal(t, m.Registered, true) + check.Equal(t, m.Next, int64(-1)) + check.Equal(t, m.Previous, int64(-1)) + check.Equal(t, m.Source, "") + } + + t.Run("all_tx", func(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + err := SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}), + ) + check.NoError(t, err) + check.Number(t, len(registeredGoMigrations), 1) + m := registeredGoMigrations[1] + assertMigration(t, m, 1) + // Legacy functions. + check.Bool(t, m.UpFnNoTxContext == nil, true) + check.Bool(t, m.DownFnNoTxContext == nil, true) + // Context-aware functions. + check.Bool(t, m.goUp == nil, false) + check.Bool(t, m.UpFnContext == nil, false) + check.Bool(t, m.goDown == nil, false) + check.Bool(t, m.DownFnContext == nil, false) + // Always nil + check.Bool(t, m.UpFn == nil, false) + check.Bool(t, m.DownFn == nil, false) + check.Bool(t, m.UpFnNoTx == nil, true) + check.Bool(t, m.DownFnNoTx == nil, true) + }) + t.Run("all_db", func(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + err := SetGlobalMigrations( + NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}), + ) + check.NoError(t, err) + check.Number(t, len(registeredGoMigrations), 1) + m := registeredGoMigrations[2] + assertMigration(t, m, 2) + // Legacy functions. + check.Bool(t, m.UpFnNoTxContext == nil, false) + check.Bool(t, m.goUp == nil, false) + check.Bool(t, m.DownFnNoTxContext == nil, false) + check.Bool(t, m.goDown == nil, false) + // Context-aware functions. + check.Bool(t, m.UpFnContext == nil, true) + check.Bool(t, m.DownFnContext == nil, true) + // Always nil + check.Bool(t, m.UpFn == nil, true) + check.Bool(t, m.DownFn == nil, true) + check.Bool(t, m.UpFnNoTx == nil, false) + check.Bool(t, m.DownFnNoTx == nil, false) + }) +} + +func TestGlobalRegister(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) + + // runDB := func(context.Context, *sql.DB) error { return nil } + runTx := func(context.Context, *sql.Tx) error { return nil } // Success. - err := goose.SetGlobalMigrations( - []goose.Migration{}..., + err := SetGlobalMigrations([]Migration{}...) + check.NoError(t, err) + err = SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunTx: runTx}, nil), ) check.NoError(t, err) - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn}, - ) - check.NoError(t, err) - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo}, + // Try to register the same migration again. + err = SetGlobalMigrations( + NewGoMigration(1, &GoFunc{RunTx: runTx}, nil), ) check.HasError(t, err) check.Contains(t, err.Error(), "go migration with version 1 already registered") - err = goose.SetGlobalMigrations( - goose.Migration{ - Registered: true, - Version: 2, - Source: "00002_foo.sql", - Type: goose.TypeGo, - UpFnContext: func(context.Context, *sql.Tx) error { return nil }, - DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, - }, - ) - check.NoError(t, err) - // Reset. - { - goose.ResetGlobalMigrations() - } - // Failure. - err = goose.SetGlobalMigrations( - goose.Migration{}, - ) + err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must be registered") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`) - // Legacy functions. - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx") - // Context-aware functions. - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext") - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext") - // Source and version mismatch. - err = goose.SetGlobalMigrations( - goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo}, - ) - check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`) + check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") +} + +func TestCheckMigration(t *testing.T) { + // Failures. + err := checkMigration(&Migration{}) + check.HasError(t, err) + check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") + err = checkMigration(&Migration{construct: true}) + check.HasError(t, err) + check.Contains(t, err.Error(), "must be registered") + err = checkMigration(&Migration{construct: true, Registered: true}) + check.HasError(t, err) + check.Contains(t, err.Error(), `type must be "go"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo}) + check.HasError(t, err) + check.Contains(t, err.Error(), "version must be greater than zero") + // Success. + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1}) + check.NoError(t, err) + // Failures. + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"}) + check.HasError(t, err) + check.Contains(t, err.Error(), `source must have .go extension: "foo"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"}) + check.HasError(t, err) + check.Contains(t, err.Error(), `no filename separator '_' found`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"}) + check.HasError(t, err) + check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"}) + check.HasError(t, err) + check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + DownFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + UpFn: func(*sql.Tx) error { return nil }, + UpFnNoTx: func(*sql.DB) error { return nil }, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + DownFn: func(*sql.Tx) error { return nil }, + DownFnNoTx: func(*sql.DB) error { return nil }, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx") } diff --git a/internal/check/check.go b/internal/check/check.go index f5d1b6d..76dfac7 100644 --- a/internal/check/check.go +++ b/internal/check/check.go @@ -8,13 +8,6 @@ import ( "testing" ) -func NotNil(t *testing.T, v any) { - t.Helper() - if v == nil { - t.Fatal("unexpected nil value") - } -} - func NoError(t *testing.T, err error) { t.Helper() if err != nil { diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index b88069c..4d12a0d 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -113,7 +113,7 @@ func TestProviderRun(t *testing.T) { break } check.NoError(t, err) - check.NotNil(t, res) + check.Bool(t, res != nil, true) check.Number(t, res.Source.Version, int64(counter)) } currentVersion, err := p.GetDBVersion(ctx) @@ -132,7 +132,7 @@ func TestProviderRun(t *testing.T) { break } check.NoError(t, err) - check.NotNil(t, res) + check.Bool(t, res != nil, true) check.Number(t, res.Source.Version, int64(maxVersion-counter+1)) } // Once everything is tested the version should match the highest testdata version @@ -632,12 +632,12 @@ func TestAllowMissing(t *testing.T) { // 4 upResult, err := p.UpByOne(ctx) check.NoError(t, err) - check.NotNil(t, upResult) + check.Bool(t, upResult != nil, true) check.Number(t, upResult.Source.Version, 4) // 6 upResult, err = p.UpByOne(ctx) check.NoError(t, err) - check.NotNil(t, upResult) + check.Bool(t, upResult != nil, true) check.Number(t, upResult.Source.Version, 6) count, err := getGooseVersionCount(db, provider.DefaultTablename) @@ -660,7 +660,7 @@ func TestAllowMissing(t *testing.T) { check.Number(t, currentVersion, wantDBVersion) downRes, err := p.Down(ctx) check.NoError(t, err) - check.NotNil(t, downRes) + check.Bool(t, downRes != nil, true) check.Number(t, downRes.Source.Version, wantResultVersion) } @@ -897,7 +897,7 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - check.NotNil(t, result) + check.Bool(t, result != nil, true) mu.Lock() applied = append(applied, result.Source.Version) mu.Unlock() @@ -913,7 +913,7 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - check.NotNil(t, result) + check.Bool(t, result != nil, true) mu.Lock() applied = append(applied, result.Source.Version) mu.Unlock() @@ -999,7 +999,7 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - check.NotNil(t, result) + check.Bool(t, result != nil, true) mu.Lock() applied = append(applied, result.Source.Version) mu.Unlock() @@ -1015,7 +1015,7 @@ func TestLockModeAdvisorySession(t *testing.T) { return err } check.NoError(t, err) - check.NotNil(t, result) + check.Bool(t, result != nil, true) mu.Lock() applied = append(applied, result.Source.Version) mu.Unlock() @@ -1127,7 +1127,7 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) { t.Helper() - check.NotNil(t, got) + check.Bool(t, got != nil, true) check.Equal(t, got.Source, source) check.Equal(t, got.Direction, direction) check.Equal(t, got.Empty, isEmpty) diff --git a/migrate.go b/migrate.go index 599810c..22769ff 100644 --- a/migrate.go +++ b/migrate.go @@ -8,7 +8,6 @@ import ( "io/fs" "math" "path" - "runtime" "sort" "strings" "time" @@ -125,115 +124,6 @@ func (ms Migrations) String() string { return str } -// GoMigration is a Go migration func that is run within a transaction. -type GoMigration func(tx *sql.Tx) error - -// GoMigrationContext is a Go migration func that is run within a transaction and receives a context. -type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error - -// GoMigrationNoTx is a Go migration func that is run outside a transaction. -type GoMigrationNoTx func(db *sql.DB) error - -// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context. -type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error - -// AddMigration adds Go migrations. -// -// Deprecated: Use AddMigrationContext. -func AddMigration(up, down GoMigration) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationContext(filename, withContext(up), withContext(down)) -} - -// AddMigrationContext adds Go migrations. -func AddMigrationContext(up, down GoMigrationContext) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationContext(filename, up, down) -} - -// AddNamedMigration adds named Go migrations. -// -// Deprecated: Use AddNamedMigrationContext. -func AddNamedMigration(filename string, up, down GoMigration) { - AddNamedMigrationContext(filename, withContext(up), withContext(down)) -} - -// AddNamedMigrationContext adds named Go migrations. -func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { - if err := register(filename, true, up, down, nil, nil); err != nil { - panic(err) - } -} - -// AddMigrationNoTx adds Go migrations that will be run outside transaction. -// -// Deprecated: Use AddNamedMigrationNoTxContext. -func AddMigrationNoTx(up, down GoMigrationNoTx) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) -} - -// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. -func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTxContext(filename, up, down) -} - -// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. -// -// Deprecated: Use AddNamedMigrationNoTxContext. -func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { - AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) -} - -// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. -func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { - if err := register(filename, false, nil, nil, up, down); err != nil { - panic(err) - } -} - -func register( - filename string, - useTx bool, - up, down GoMigrationContext, - upNoTx, downNoTx GoMigrationNoTxContext, -) error { - // Sanity check caller did not mix tx and non-tx based functions. - if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { - return fmt.Errorf("cannot mix tx and non-tx based go migrations functions") - } - v, _ := NumericComponent(filename) - if existing, ok := registeredGoMigrations[v]; ok { - return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", - filename, - v, - existing.Source, - ) - } - // Add to global as a registered migration. - registeredGoMigrations[v] = &Migration{ - Version: v, - Next: -1, - Previous: -1, - Registered: true, - Source: filename, - UseTx: useTx, - UpFnContext: up, - DownFnContext: down, - UpFnNoTxContext: upNoTx, - DownFnNoTxContext: downNoTx, - // These are deprecated and will be removed in the future. - // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. - // Goose does not use these internally anymore and instead uses the context versions. - UpFn: withoutContext(up), - DownFn: withoutContext(down), - UpFnNoTx: withoutContext(upNoTx), - DownFnNoTx: withoutContext(downNoTx), - } - return nil -} - func collectMigrationsFS( fsys fs.FS, dirpath string, @@ -388,29 +278,6 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { return version, nil } -// withContext changes the signature of a function that receives one argument to receive a context and the argument. -func withContext[T any](fn func(T) error) func(context.Context, T) error { - if fn == nil { - return nil - } - - return func(ctx context.Context, t T) error { - return fn(t) - } -} - -// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument. -// When called the passed context is always context.Background(). -func withoutContext[T any](fn func(context.Context, T) error) func(T) error { - if fn == nil { - return nil - } - - return func(t T) error { - return fn(context.Background(), t) - } -} - // collectGoMigrations collects Go migrations from the filesystem and merges them with registered // migrations. // diff --git a/migration.go b/migration.go index d81e589..be23d03 100644 --- a/migration.go +++ b/migration.go @@ -13,36 +13,141 @@ import ( "github.com/pressly/goose/v3/internal/sqlparser" ) +// NewGoMigration creates a new Go migration. +// +// Both up and down functions may be nil, in which case the migration will be recorded in the +// versions table but no functions will be run. This is useful for recording (up) or deleting (down) +// a version without running any functions. See [GoFunc] for more details. +func NewGoMigration(version int64, up, down *GoFunc) Migration { + m := Migration{ + Type: TypeGo, + Registered: true, + Version: version, + Next: -1, Previous: -1, + goUp: up, + goDown: down, + construct: true, + } + // To maintain backwards compatibility, we set ALL legacy functions. In a future major version, + // we will remove these fields in favor of [GoFunc]. + // + // Note, this function does not do any validation. Validation is lazily done when the migration + // is registered. + if up != nil { + if up.RunDB != nil { + m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error + m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error + } + if up.RunTx != nil { + m.UseTx = true + m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error + m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error + } + } + if down != nil { + if down.RunDB != nil { + m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error + m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error + } + if down.RunTx != nil { + m.UseTx = true + m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error + m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error + } + } + if m.goUp == nil { + m.goUp = &GoFunc{Mode: TransactionEnabled} + } + if m.goDown == nil { + m.goDown = &GoFunc{Mode: TransactionEnabled} + } + return m +} + +// Migration struct represents either a SQL or Go migration. +// +// Avoid constructing migrations manually, use [NewGoMigration] function. +type Migration struct { + Type MigrationType + Version int64 + // Source is the path to the .sql script or .go file. It may be empty for Go migrations that + // have been registered globally and don't have a source file. + Source string + + UpFnContext, DownFnContext GoMigrationContext + UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext + // These fields are used internally by goose and users are not expected to set them. Instead, + // use [NewGoMigration] to create a new go migration. + construct bool + goUp, goDown *GoFunc + + // These fields will be removed in a future major version. They are here for backwards + // compatibility and are an implementation detail. + Registered bool + UseTx bool + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + + // We still save the non-context versions in the struct in case someone is using them. Goose + // does not use these internally anymore in favor of the context-aware versions. These fields + // will be removed in a future major version. + + UpFn GoMigration // Deprecated: use UpFnContext instead. + DownFn GoMigration // Deprecated: use DownFnContext instead. + UpFnNoTx GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead. + DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead. + + noVersioning bool +} + +// GoFunc represents a Go migration function. +type GoFunc struct { + // Exactly one of these must be set, or both must be nil. + RunTx func(ctx context.Context, tx *sql.Tx) error + // -- OR -- + RunDB func(ctx context.Context, db *sql.DB) error + + // Mode is the transaction mode for the migration. When one of the run functions is set, the + // mode will be inferred from the function and the field is ignored. Users do not need to set + // this field when supplying a run function. + // + // If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil + // functions is to record a version in the version table without invoking a Go migration + // function. + // + // The only time this field is required is if BOTH run functions are nil AND you want to + // override the default transaction mode. + Mode TransactionMode +} + +// TransactionMode represents the possible transaction modes for a migration. +type TransactionMode int + +const ( + TransactionEnabled TransactionMode = iota + 1 + TransactionDisabled +) + +func (m TransactionMode) String() string { + switch m { + case TransactionEnabled: + return "transaction_enabled" + case TransactionDisabled: + return "transaction_disabled" + default: + return fmt.Sprintf("unknown transaction mode (%d)", m) + } +} + // MigrationRecord struct. +// +// Deprecated: unused and will be removed in a future major version. type MigrationRecord struct { VersionID int64 TStamp time.Time IsApplied bool // was this a result of up() or down() } -// Migration struct represents either a SQL or Go migration. -type Migration struct { - Type MigrationType - Version int64 - Source string // path to .sql script or .go file - Registered bool - UpFnContext, DownFnContext GoMigrationContext - UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext - - // These fields will be removed in a future major version. They are here for backwards - // compatibility and are an implementation detail. - UseTx bool - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - - // We still save the non-context versions in the struct in case someone is using them. Goose - // does not use these internally anymore in favor of the context-aware versions. - UpFn, DownFn GoMigration - UpFnNoTx, DownFnNoTx GoMigrationNoTx - - noVersioning bool -} - func (m *Migration) String() string { return fmt.Sprint(m.Source) } diff --git a/register.go b/register.go new file mode 100644 index 0000000..89bd4c7 --- /dev/null +++ b/register.go @@ -0,0 +1,133 @@ +package goose + +import ( + "context" + "database/sql" + "fmt" + "runtime" +) + +// GoMigrationContext is a Go migration func that is run within a transaction and receives a +// context. +type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error + +// AddMigrationContext adds Go migrations. +func AddMigrationContext(up, down GoMigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { + if err := register( + filename, + true, + &GoFunc{RunTx: up}, + &GoFunc{RunTx: down}, + ); err != nil { + panic(err) + } +} + +// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a +// context. +type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error + +// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. +func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTxContext(filename, up, down) +} + +// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. +func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { + if err := register( + filename, + false, + &GoFunc{RunDB: up}, + &GoFunc{RunDB: down}, + ); err != nil { + panic(err) + } +} + +func register(filename string, useTx bool, up, down *GoFunc) error { + v, _ := NumericComponent(filename) + if existing, ok := registeredGoMigrations[v]; ok { + return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", + filename, + v, + existing.Source, + ) + } + // Add to global as a registered migration. + m := NewGoMigration(v, up, down) + m.Source = filename + // We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but + // we know based on the register function what the user is requesting. + m.UseTx = useTx + registeredGoMigrations[v] = &m + return nil +} + +// withContext changes the signature of a function that receives one argument to receive a context +// and the argument. +func withContext[T any](fn func(T) error) func(context.Context, T) error { + if fn == nil { + return nil + } + return func(ctx context.Context, t T) error { + return fn(t) + } +} + +// withoutContext changes the signature of a function that receives a context and one argument to +// receive only the argument. When called the passed context is always context.Background(). +func withoutContext[T any](fn func(context.Context, T) error) func(T) error { + if fn == nil { + return nil + } + return func(t T) error { + return fn(context.Background(), t) + } +} + +// GoMigration is a Go migration func that is run within a transaction. +// +// Deprecated: Use GoMigrationContext. +type GoMigration func(tx *sql.Tx) error + +// GoMigrationNoTx is a Go migration func that is run outside a transaction. +// +// Deprecated: Use GoMigrationNoTxContext. +type GoMigrationNoTx func(db *sql.DB) error + +// AddMigration adds Go migrations. +// +// Deprecated: Use AddMigrationContext. +func AddMigration(up, down GoMigration) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigration adds named Go migrations. +// +// Deprecated: Use AddNamedMigrationContext. +func AddNamedMigration(filename string, up, down GoMigration) { + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddMigrationNoTx adds Go migrations that will be run outside transaction. +// +// Deprecated: Use AddMigrationNoTxContext. +func AddMigrationNoTx(up, down GoMigrationNoTx) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. +// +// Deprecated: Use AddNamedMigrationNoTxContext. +func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { + AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) +} diff --git a/tests/gomigrations/success/gomigrations_success_test.go b/tests/gomigrations/success/gomigrations_success_test.go index 87bebf4..306efe2 100644 --- a/tests/gomigrations/success/gomigrations_success_test.go +++ b/tests/gomigrations/success/gomigrations_success_test.go @@ -1,27 +1,30 @@ -package gomigrations +package gomigrations_test import ( + "database/sql" "path/filepath" "testing" "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/testdb" _ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata" + _ "modernc.org/sqlite" ) func TestGoMigrationByOne(t *testing.T) { - db, cleanup, err := testdb.NewPostgres() - check.NoError(t, err) - t.Cleanup(cleanup) + t.Parallel() + check.NoError(t, goose.SetDialect("sqlite3")) + db, err := sql.Open("sqlite", ":memory:") + check.NoError(t, err) dir := "testdata" files, err := filepath.Glob(dir + "/*.go") check.NoError(t, err) upByOne := func(t *testing.T) int64 { err = goose.UpByOne(db, dir) + t.Logf("err: %v %s", err, dir) check.NoError(t, err) version, err := goose.GetDBVersion(db) check.NoError(t, err) @@ -42,6 +45,21 @@ func TestGoMigrationByOne(t *testing.T) { check.NoError(t, err) check.Number(t, version, len(files)) + tables, err := ListTables(db) + check.NoError(t, err) + check.Equal(t, tables, []string{ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "goose_db_version", + "hotel", + "sqlite_sequence", + }) + // Migrate all files down-by-one. for i := len(files) - 1; i >= 0; i-- { check.Number(t, downByOne(t), i) @@ -49,4 +67,31 @@ func TestGoMigrationByOne(t *testing.T) { version, err = goose.GetDBVersion(db) check.NoError(t, err) check.Number(t, version, 0) + + tables, err = ListTables(db) + check.NoError(t, err) + check.Equal(t, tables, []string{ + "goose_db_version", + "sqlite_sequence", + }) +} + +func ListTables(db *sql.DB) ([]string, error) { + rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var tables []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + tables = append(tables, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return tables, nil } diff --git a/tests/gomigrations/success/testdata/001_up_down.go b/tests/gomigrations/success/testdata/001_up_down.go index 9fed61c..b3c65e9 100644 --- a/tests/gomigrations/success/testdata/001_up_down.go +++ b/tests/gomigrations/success/testdata/001_up_down.go @@ -1,9 +1,12 @@ package gomigrations import ( + "context" "database/sql" + "fmt" "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" ) func init() { @@ -11,13 +14,19 @@ func init() { } func up001(tx *sql.Tx) error { - q := "CREATE TABLE foo (id INT, subid INT, name TEXT)" - _, err := tx.Exec(q) - return err + return createTable(tx, "alpha") } func down001(tx *sql.Tx) error { - q := "DROP TABLE IF EXISTS foo" - _, err := tx.Exec(q) + return dropTable(tx, "alpha") +} + +func createTable(db database.DBTxConn, name string) error { + _, err := db.ExecContext(context.Background(), fmt.Sprintf("CREATE TABLE %s (id INTEGER)", name)) + return err +} + +func dropTable(db database.DBTxConn, name string) error { + _, err := db.ExecContext(context.Background(), fmt.Sprintf("DROP TABLE %s", name)) return err } diff --git a/tests/gomigrations/success/testdata/002_up_only.go b/tests/gomigrations/success/testdata/002_up_only.go index 6ece192..e5aab5a 100644 --- a/tests/gomigrations/success/testdata/002_up_only.go +++ b/tests/gomigrations/success/testdata/002_up_only.go @@ -11,7 +11,5 @@ func init() { } func up002(tx *sql.Tx) error { - q := "INSERT INTO foo VALUES (1, 1, 'Alice')" - _, err := tx.Exec(q) - return err + return createTable(tx, "bravo") } diff --git a/tests/gomigrations/success/testdata/003_down_only.go b/tests/gomigrations/success/testdata/003_down_only.go index ff39f5f..c9d062b 100644 --- a/tests/gomigrations/success/testdata/003_down_only.go +++ b/tests/gomigrations/success/testdata/003_down_only.go @@ -11,7 +11,5 @@ func init() { } func down003(tx *sql.Tx) error { - q := "TRUNCATE TABLE foo" - _, err := tx.Exec(q) - return err + return dropTable(tx, "bravo") } diff --git a/tests/gomigrations/success/testdata/005_up_down_no_tx.go b/tests/gomigrations/success/testdata/005_up_down_no_tx.go index 7a6838d..1ef5a57 100644 --- a/tests/gomigrations/success/testdata/005_up_down_no_tx.go +++ b/tests/gomigrations/success/testdata/005_up_down_no_tx.go @@ -11,13 +11,9 @@ func init() { } func up005(db *sql.DB) error { - q := "CREATE TABLE users (id INT, email TEXT)" - _, err := db.Exec(q) - return err + return createTable(db, "charlie") } func down005(db *sql.DB) error { - q := "DROP TABLE IF EXISTS users" - _, err := db.Exec(q) - return err + return dropTable(db, "charlie") } diff --git a/tests/gomigrations/success/testdata/006_up_only_no_tx.go b/tests/gomigrations/success/testdata/006_up_only_no_tx.go index 26aa88c..2aa770c 100644 --- a/tests/gomigrations/success/testdata/006_up_only_no_tx.go +++ b/tests/gomigrations/success/testdata/006_up_only_no_tx.go @@ -11,7 +11,5 @@ func init() { } func up006(db *sql.DB) error { - q := "INSERT INTO users VALUES (1, 'admin@example.com')" - _, err := db.Exec(q) - return err + return createTable(db, "delta") } diff --git a/tests/gomigrations/success/testdata/007_down_only_no_tx.go b/tests/gomigrations/success/testdata/007_down_only_no_tx.go index 318b02e..86edd41 100644 --- a/tests/gomigrations/success/testdata/007_down_only_no_tx.go +++ b/tests/gomigrations/success/testdata/007_down_only_no_tx.go @@ -11,7 +11,5 @@ func init() { } func down007(db *sql.DB) error { - q := "TRUNCATE TABLE users" - _, err := db.Exec(q) - return err + return dropTable(db, "delta") } diff --git a/tests/gomigrations/success/testdata/008_empty_no_tx.go b/tests/gomigrations/success/testdata/008_empty_no_tx.go index 5efb376..76aaedf 100644 --- a/tests/gomigrations/success/testdata/008_empty_no_tx.go +++ b/tests/gomigrations/success/testdata/008_empty_no_tx.go @@ -5,5 +5,5 @@ import ( ) func init() { - goose.AddMigration(nil, nil) + goose.AddMigrationNoTx(nil, nil) } diff --git a/tests/gomigrations/success/testdata/009_up_down_ctx.go b/tests/gomigrations/success/testdata/009_up_down_ctx.go new file mode 100644 index 0000000..09ce310 --- /dev/null +++ b/tests/gomigrations/success/testdata/009_up_down_ctx.go @@ -0,0 +1,20 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(up009, down009) +} + +func up009(ctx context.Context, tx *sql.Tx) error { + return createTable(tx, "echo") +} + +func down009(ctx context.Context, tx *sql.Tx) error { + return dropTable(tx, "echo") +} diff --git a/tests/gomigrations/success/testdata/010_up_only_ctx.go b/tests/gomigrations/success/testdata/010_up_only_ctx.go new file mode 100644 index 0000000..0439530 --- /dev/null +++ b/tests/gomigrations/success/testdata/010_up_only_ctx.go @@ -0,0 +1,16 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(up010, nil) +} + +func up010(ctx context.Context, tx *sql.Tx) error { + return createTable(tx, "foxtrot") +} diff --git a/tests/gomigrations/success/testdata/011_down_only_ctx.go b/tests/gomigrations/success/testdata/011_down_only_ctx.go new file mode 100644 index 0000000..c1f1fc7 --- /dev/null +++ b/tests/gomigrations/success/testdata/011_down_only_ctx.go @@ -0,0 +1,16 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(nil, down011) +} + +func down011(ctx context.Context, tx *sql.Tx) error { + return dropTable(tx, "foxtrot") +} diff --git a/tests/gomigrations/success/testdata/012_empty_ctx.go b/tests/gomigrations/success/testdata/012_empty_ctx.go new file mode 100644 index 0000000..a8df353 --- /dev/null +++ b/tests/gomigrations/success/testdata/012_empty_ctx.go @@ -0,0 +1,9 @@ +package gomigrations + +import ( + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(nil, nil) +} diff --git a/tests/gomigrations/success/testdata/013_up_down_no_tx_ctx.go b/tests/gomigrations/success/testdata/013_up_down_no_tx_ctx.go new file mode 100644 index 0000000..da3a27c --- /dev/null +++ b/tests/gomigrations/success/testdata/013_up_down_no_tx_ctx.go @@ -0,0 +1,20 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTxContext(up013, down013) +} + +func up013(ctx context.Context, db *sql.DB) error { + return createTable(db, "golf") +} + +func down013(ctx context.Context, db *sql.DB) error { + return dropTable(db, "golf") +} diff --git a/tests/gomigrations/success/testdata/014_up_only_no_tx_ctx.go b/tests/gomigrations/success/testdata/014_up_only_no_tx_ctx.go new file mode 100644 index 0000000..a32de86 --- /dev/null +++ b/tests/gomigrations/success/testdata/014_up_only_no_tx_ctx.go @@ -0,0 +1,16 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTxContext(up014, nil) +} + +func up014(ctx context.Context, db *sql.DB) error { + return createTable(db, "hotel") +} diff --git a/tests/gomigrations/success/testdata/015_down_only_no_tx_ctx.go b/tests/gomigrations/success/testdata/015_down_only_no_tx_ctx.go new file mode 100644 index 0000000..18b8c2d --- /dev/null +++ b/tests/gomigrations/success/testdata/015_down_only_no_tx_ctx.go @@ -0,0 +1,16 @@ +package gomigrations + +import ( + "context" + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTxContext(nil, down015) +} + +func down015(ctx context.Context, db *sql.DB) error { + return dropTable(db, "hotel") +} diff --git a/tests/gomigrations/success/testdata/016_empty_no_tx_ctx.go b/tests/gomigrations/success/testdata/016_empty_no_tx_ctx.go new file mode 100644 index 0000000..e97c86b --- /dev/null +++ b/tests/gomigrations/success/testdata/016_empty_no_tx_ctx.go @@ -0,0 +1,9 @@ +package gomigrations + +import ( + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTxContext(nil, nil) +}