feat: Add NewGoMigration constructor (#631)

This commit is contained in:
Michael Fridman 2023-11-03 22:11:27 -04:00 committed by GitHub
parent 432a6ac0f8
commit 6d0de39c50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 786 additions and 328 deletions

View File

@ -3,44 +3,42 @@ package goose
import (
"errors"
"fmt"
"path/filepath"
)
var (
registeredGoMigrations = make(map[int64]*Migration)
)
// ResetGlobalMigrations resets the global go migrations registry.
// ResetGlobalMigrations resets the global Go migrations registry.
//
// Not safe for concurrent use.
func ResetGlobalMigrations() {
registeredGoMigrations = make(map[int64]*Migration)
}
// SetGlobalMigrations registers go migrations globally. It returns an error if a migration with the
// same version has already been registered.
//
// Source may be empty, but if it is set, it must be a path with a numeric component that matches
// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx.
// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the
// same version has already been registered. Go migrations must be constructed using the
// [NewGoMigration] function.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, m := range migrations {
// make a copy of the migration so we can modify it without affecting the original.
if err := validGoMigration(&m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
for _, migration := range migrations {
m := &migration
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user.
registeredGoMigrations[m.Version] = &m
if err := checkMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
registeredGoMigrations[m.Version] = m
}
return nil
}
func validGoMigration(m *Migration) error {
if m == nil {
return errors.New("must not be nil")
func checkMigration(m *Migration) error {
if !m.construct {
return errors.New("must use NewGoMigration to construct migrations")
}
if !m.Registered {
return errors.New("must be registered")
@ -52,36 +50,81 @@ func validGoMigration(m *Migration) error {
return errors.New("version must be greater than zero")
}
if m.Source != "" {
if filepath.Ext(m.Source) != ".go" {
return fmt.Errorf("source must have .go extension: %q", m.Source)
}
// If the source is set, expect it to be a path with a numeric component that matches the
// version. This field is not intended to be used for descriptive purposes.
version, err := NumericComponent(m.Source)
if err != nil {
return err
return fmt.Errorf("invalid source: %w", err)
}
if version != m.Version {
return fmt.Errorf("numeric component [%d] in go migration does not match version in source %q", m.Version, m.Source)
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
}
}
// It's valid for all of these funcs to be nil. Which means version the go migration but do not
// run anything.
if err := setGoFunc(m.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
}
if err := setGoFunc(m.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.UpFn != nil && m.UpFnNoTx != nil {
return errors.New("must specify exactly one of UpFn or UpFnNoTx")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
// Do not allow legacy functions to be set.
if m.UpFn != nil {
return errors.New("must not specify UpFn")
}
if m.DownFn != nil {
return errors.New("must not specify DownFn")
}
if m.UpFnNoTx != nil {
return errors.New("must not specify UpFnNoTx")
}
if m.DownFnNoTx != nil {
return errors.New("must not specify DownFnNoTx")
if m.DownFn != nil && m.DownFnNoTx != nil {
return errors.New("must specify exactly one of DownFn or DownFnNoTx")
}
return nil
}
func setGoFunc(f *GoFunc) error {
if f == nil {
f = &GoFunc{Mode: TransactionEnabled}
return nil
}
if f.RunTx != nil && f.RunDB != nil {
return errors.New("must specify exactly one of RunTx or RunDB")
}
if f.RunTx == nil && f.RunDB == nil {
switch f.Mode {
case 0:
// Default to TransactionEnabled ONLY if mode is not set explicitly.
f.Mode = TransactionEnabled
case TransactionEnabled, TransactionDisabled:
// No functions but mode is set. This is not an error. It means the user wants to record
// a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
return nil
}
if f.RunDB != nil {
switch f.Mode {
case 0, TransactionDisabled:
f.Mode = TransactionDisabled
default:
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
}
if f.RunTx != nil {
switch f.Mode {
case 0, TransactionEnabled:
f.Mode = TransactionEnabled
default:
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
}
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
// the functions or return an error. This should never happen.
if f.Mode == 0 {
return errors.New("failed to infer transaction mode")
}
return nil
}

View File

@ -1,113 +1,266 @@
package goose_test
package goose
import (
"context"
"database/sql"
"testing"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
)
func TestGlobalRegister(t *testing.T) {
// Avoid polluting other tests and do not run in parallel.
t.Cleanup(func() {
goose.ResetGlobalMigrations()
func TestNewGoMigration(t *testing.T) {
t.Run("valid_both_nil", func(t *testing.T) {
m := NewGoMigration(1, nil, nil)
// roundtrip
check.Equal(t, m.Version, int64(1))
check.Equal(t, m.Type, TypeGo)
check.Equal(t, m.Registered, true)
check.Equal(t, m.Next, int64(-1))
check.Equal(t, m.Previous, int64(-1))
check.Equal(t, m.Source, "")
check.Bool(t, m.UpFnNoTxContext == nil, true)
check.Bool(t, m.DownFnNoTxContext == nil, true)
check.Bool(t, m.UpFnContext == nil, true)
check.Bool(t, m.DownFnContext == nil, true)
check.Bool(t, m.UpFn == nil, true)
check.Bool(t, m.DownFn == nil, true)
check.Bool(t, m.UpFnNoTx == nil, true)
check.Bool(t, m.DownFnNoTx == nil, true)
check.Bool(t, m.goUp != nil, true)
check.Bool(t, m.goDown != nil, true)
check.Equal(t, m.goUp.Mode, TransactionEnabled)
check.Equal(t, m.goDown.Mode, TransactionEnabled)
})
fnNoTx := func(context.Context, *sql.DB) error { return nil }
fn := func(context.Context, *sql.Tx) error { return nil }
t.Run("all_set", func(t *testing.T) {
// This will eventually be an error when registering migrations.
m := NewGoMigration(
1,
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
)
// check only functions
check.Bool(t, m.UpFn != nil, true)
check.Bool(t, m.UpFnContext != nil, true)
check.Bool(t, m.UpFnNoTx != nil, true)
check.Bool(t, m.UpFnNoTxContext != nil, true)
check.Bool(t, m.DownFn != nil, true)
check.Bool(t, m.DownFnContext != nil, true)
check.Bool(t, m.DownFnNoTx != nil, true)
check.Bool(t, m.DownFnNoTxContext != nil, true)
})
}
func TestTransactionMode(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
runDB := func(context.Context, *sql.DB) error { return nil }
runTx := func(context.Context, *sql.Tx) error { return nil }
err := SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both
)
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB")
err = SetGlobalMigrations(
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both
)
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB")
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx
)
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set")
err = SetGlobalMigrations(
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx
)
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set")
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx
)
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set")
err = SetGlobalMigrations(
NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx
)
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set")
t.Run("default_mode", func(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
m := NewGoMigration(1, nil, nil)
err = SetGlobalMigrations(m)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 1)
registered := registeredGoMigrations[1]
check.Bool(t, registered.goUp != nil, true)
check.Bool(t, registered.goDown != nil, true)
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
migration2 := NewGoMigration(2, nil, nil)
// reset so we can check the default is set
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
err = SetGlobalMigrations(migration2)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 2)
registered = registeredGoMigrations[2]
check.Bool(t, registered.goUp != nil, true)
check.Bool(t, registered.goDown != nil, true)
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
})
t.Run("unknown_mode", func(t *testing.T) {
m := NewGoMigration(1, nil, nil)
m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default
err := SetGlobalMigrations(m)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid mode: 3")
})
}
func TestLegacyFunctions(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
runDB := func(context.Context, *sql.DB) error { return nil }
runTx := func(context.Context, *sql.Tx) error { return nil }
assertMigration := func(t *testing.T, m *Migration, version int64) {
t.Helper()
check.Equal(t, m.Version, version)
check.Equal(t, m.Type, TypeGo)
check.Equal(t, m.Registered, true)
check.Equal(t, m.Next, int64(-1))
check.Equal(t, m.Previous, int64(-1))
check.Equal(t, m.Source, "")
}
t.Run("all_tx", func(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
err := SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}),
)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 1)
m := registeredGoMigrations[1]
assertMigration(t, m, 1)
// Legacy functions.
check.Bool(t, m.UpFnNoTxContext == nil, true)
check.Bool(t, m.DownFnNoTxContext == nil, true)
// Context-aware functions.
check.Bool(t, m.goUp == nil, false)
check.Bool(t, m.UpFnContext == nil, false)
check.Bool(t, m.goDown == nil, false)
check.Bool(t, m.DownFnContext == nil, false)
// Always nil
check.Bool(t, m.UpFn == nil, false)
check.Bool(t, m.DownFn == nil, false)
check.Bool(t, m.UpFnNoTx == nil, true)
check.Bool(t, m.DownFnNoTx == nil, true)
})
t.Run("all_db", func(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
err := SetGlobalMigrations(
NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}),
)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 1)
m := registeredGoMigrations[2]
assertMigration(t, m, 2)
// Legacy functions.
check.Bool(t, m.UpFnNoTxContext == nil, false)
check.Bool(t, m.goUp == nil, false)
check.Bool(t, m.DownFnNoTxContext == nil, false)
check.Bool(t, m.goDown == nil, false)
// Context-aware functions.
check.Bool(t, m.UpFnContext == nil, true)
check.Bool(t, m.DownFnContext == nil, true)
// Always nil
check.Bool(t, m.UpFn == nil, true)
check.Bool(t, m.DownFn == nil, true)
check.Bool(t, m.UpFnNoTx == nil, false)
check.Bool(t, m.DownFnNoTx == nil, false)
})
}
func TestGlobalRegister(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
// runDB := func(context.Context, *sql.DB) error { return nil }
runTx := func(context.Context, *sql.Tx) error { return nil }
// Success.
err := goose.SetGlobalMigrations(
[]goose.Migration{}...,
err := SetGlobalMigrations([]Migration{}...)
check.NoError(t, err)
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo, UpFnContext: fn},
)
check.NoError(t, err)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeGo},
// Try to register the same migration again.
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = goose.SetGlobalMigrations(
goose.Migration{
Registered: true,
Version: 2,
Source: "00002_foo.sql",
Type: goose.TypeGo,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
},
)
check.NoError(t, err)
// Reset.
{
goose.ResetGlobalMigrations()
}
// Failure.
err = goose.SetGlobalMigrations(
goose.Migration{},
)
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must be registered")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Type: goose.TypeSQL},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 0, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: numeric component [1] in go migration does not match version in source "2_foo.sql"`)
// Legacy functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx")
// Context-aware functions.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, UpFnContext: fn, UpFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext")
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, DownFnContext: fn, DownFnNoTxContext: fnNoTx, Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext")
// Source and version mismatch.
err = goose.SetGlobalMigrations(
goose.Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: goose.TypeGo},
)
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
}
func TestCheckMigration(t *testing.T) {
// Failures.
err := checkMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "must be registered")
err = checkMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "version must be greater than zero")
// Success.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
check.NoError(t, err)
// Failures.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `no filename separator '_' found`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFn: func(*sql.Tx) error { return nil },
UpFnNoTx: func(*sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFn: func(*sql.Tx) error { return nil },
DownFnNoTx: func(*sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
}

View File

@ -8,13 +8,6 @@ import (
"testing"
)
func NotNil(t *testing.T, v any) {
t.Helper()
if v == nil {
t.Fatal("unexpected nil value")
}
}
func NoError(t *testing.T, err error) {
t.Helper()
if err != nil {

View File

@ -113,7 +113,7 @@ func TestProviderRun(t *testing.T) {
break
}
check.NoError(t, err)
check.NotNil(t, res)
check.Bool(t, res != nil, true)
check.Number(t, res.Source.Version, int64(counter))
}
currentVersion, err := p.GetDBVersion(ctx)
@ -132,7 +132,7 @@ func TestProviderRun(t *testing.T) {
break
}
check.NoError(t, err)
check.NotNil(t, res)
check.Bool(t, res != nil, true)
check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
}
// Once everything is tested the version should match the highest testdata version
@ -632,12 +632,12 @@ func TestAllowMissing(t *testing.T) {
// 4
upResult, err := p.UpByOne(ctx)
check.NoError(t, err)
check.NotNil(t, upResult)
check.Bool(t, upResult != nil, true)
check.Number(t, upResult.Source.Version, 4)
// 6
upResult, err = p.UpByOne(ctx)
check.NoError(t, err)
check.NotNil(t, upResult)
check.Bool(t, upResult != nil, true)
check.Number(t, upResult.Source.Version, 6)
count, err := getGooseVersionCount(db, provider.DefaultTablename)
@ -660,7 +660,7 @@ func TestAllowMissing(t *testing.T) {
check.Number(t, currentVersion, wantDBVersion)
downRes, err := p.Down(ctx)
check.NoError(t, err)
check.NotNil(t, downRes)
check.Bool(t, downRes != nil, true)
check.Number(t, downRes.Source.Version, wantResultVersion)
}
@ -897,7 +897,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
check.NotNil(t, result)
check.Bool(t, result != nil, true)
mu.Lock()
applied = append(applied, result.Source.Version)
mu.Unlock()
@ -913,7 +913,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
check.NotNil(t, result)
check.Bool(t, result != nil, true)
mu.Lock()
applied = append(applied, result.Source.Version)
mu.Unlock()
@ -999,7 +999,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
check.NotNil(t, result)
check.Bool(t, result != nil, true)
mu.Lock()
applied = append(applied, result.Source.Version)
mu.Unlock()
@ -1015,7 +1015,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
return err
}
check.NoError(t, err)
check.NotNil(t, result)
check.Bool(t, result != nil, true)
mu.Lock()
applied = append(applied, result.Source.Version)
mu.Unlock()
@ -1127,7 +1127,7 @@ func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.St
func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string, isEmpty bool) {
t.Helper()
check.NotNil(t, got)
check.Bool(t, got != nil, true)
check.Equal(t, got.Source, source)
check.Equal(t, got.Direction, direction)
check.Equal(t, got.Empty, isEmpty)

View File

@ -8,7 +8,6 @@ import (
"io/fs"
"math"
"path"
"runtime"
"sort"
"strings"
"time"
@ -125,115 +124,6 @@ func (ms Migrations) String() string {
return str
}
// GoMigration is a Go migration func that is run within a transaction.
type GoMigration func(tx *sql.Tx) error
// GoMigrationContext is a Go migration func that is run within a transaction and receives a context.
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error
// GoMigrationNoTx is a Go migration func that is run outside a transaction.
type GoMigrationNoTx func(db *sql.DB) error
// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context.
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error
// AddMigration adds Go migrations.
//
// Deprecated: Use AddMigrationContext.
func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddMigrationContext adds Go migrations.
func AddMigrationContext(up, down GoMigrationContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, up, down)
}
// AddNamedMigration adds named Go migrations.
//
// Deprecated: Use AddNamedMigrationContext.
func AddNamedMigration(filename string, up, down GoMigration) {
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddNamedMigrationContext adds named Go migrations.
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
if err := register(filename, true, up, down, nil, nil); err != nil {
panic(err)
}
}
// AddMigrationNoTx adds Go migrations that will be run outside transaction.
//
// Deprecated: Use AddNamedMigrationNoTxContext.
func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}
// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTxContext(filename, up, down)
}
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
//
// Deprecated: Use AddNamedMigrationNoTxContext.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}
// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
if err := register(filename, false, nil, nil, up, down); err != nil {
panic(err)
}
}
func register(
filename string,
useTx bool,
up, down GoMigrationContext,
upNoTx, downNoTx GoMigrationNoTxContext,
) error {
// Sanity check caller did not mix tx and non-tx based functions.
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
return fmt.Errorf("cannot mix tx and non-tx based go migrations functions")
}
v, _ := NumericComponent(filename)
if existing, ok := registeredGoMigrations[v]; ok {
return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
filename,
v,
existing.Source,
)
}
// Add to global as a registered migration.
registeredGoMigrations[v] = &Migration{
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFnContext: up,
DownFnContext: down,
UpFnNoTxContext: upNoTx,
DownFnNoTxContext: downNoTx,
// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn: withoutContext(up),
DownFn: withoutContext(down),
UpFnNoTx: withoutContext(upNoTx),
DownFnNoTx: withoutContext(downNoTx),
}
return nil
}
func collectMigrationsFS(
fsys fs.FS,
dirpath string,
@ -388,29 +278,6 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
return version, nil
}
// withContext changes the signature of a function that receives one argument to receive a context and the argument.
func withContext[T any](fn func(T) error) func(context.Context, T) error {
if fn == nil {
return nil
}
return func(ctx context.Context, t T) error {
return fn(t)
}
}
// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument.
// When called the passed context is always context.Background().
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
if fn == nil {
return nil
}
return func(t T) error {
return fn(context.Background(), t)
}
}
// collectGoMigrations collects Go migrations from the filesystem and merges them with registered
// migrations.
//

View File

@ -13,36 +13,141 @@ import (
"github.com/pressly/goose/v3/internal/sqlparser"
)
// NewGoMigration creates a new Go migration.
//
// Both up and down functions may be nil, in which case the migration will be recorded in the
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
// a version without running any functions. See [GoFunc] for more details.
func NewGoMigration(version int64, up, down *GoFunc) Migration {
m := Migration{
Type: TypeGo,
Registered: true,
Version: version,
Next: -1, Previous: -1,
goUp: up,
goDown: down,
construct: true,
}
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
// we will remove these fields in favor of [GoFunc].
//
// Note, this function does not do any validation. Validation is lazily done when the migration
// is registered.
if up != nil {
if up.RunDB != nil {
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
}
if up.RunTx != nil {
m.UseTx = true
m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error
m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error
}
}
if down != nil {
if down.RunDB != nil {
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
}
if down.RunTx != nil {
m.UseTx = true
m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
}
}
if m.goUp == nil {
m.goUp = &GoFunc{Mode: TransactionEnabled}
}
if m.goDown == nil {
m.goDown = &GoFunc{Mode: TransactionEnabled}
}
return m
}
// Migration struct represents either a SQL or Go migration.
//
// Avoid constructing migrations manually, use [NewGoMigration] function.
type Migration struct {
Type MigrationType
Version int64
// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
// have been registered globally and don't have a source file.
Source string
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
// These fields are used internally by goose and users are not expected to set them. Instead,
// use [NewGoMigration] to create a new go migration.
construct bool
goUp, goDown *GoFunc
// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
Registered bool
UseTx bool
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
// We still save the non-context versions in the struct in case someone is using them. Goose
// does not use these internally anymore in favor of the context-aware versions. These fields
// will be removed in a future major version.
UpFn GoMigration // Deprecated: use UpFnContext instead.
DownFn GoMigration // Deprecated: use DownFnContext instead.
UpFnNoTx GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
noVersioning bool
}
// GoFunc represents a Go migration function.
type GoFunc struct {
// Exactly one of these must be set, or both must be nil.
RunTx func(ctx context.Context, tx *sql.Tx) error
// -- OR --
RunDB func(ctx context.Context, db *sql.DB) error
// Mode is the transaction mode for the migration. When one of the run functions is set, the
// mode will be inferred from the function and the field is ignored. Users do not need to set
// this field when supplying a run function.
//
// If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil
// functions is to record a version in the version table without invoking a Go migration
// function.
//
// The only time this field is required is if BOTH run functions are nil AND you want to
// override the default transaction mode.
Mode TransactionMode
}
// TransactionMode represents the possible transaction modes for a migration.
type TransactionMode int
const (
TransactionEnabled TransactionMode = iota + 1
TransactionDisabled
)
func (m TransactionMode) String() string {
switch m {
case TransactionEnabled:
return "transaction_enabled"
case TransactionDisabled:
return "transaction_disabled"
default:
return fmt.Sprintf("unknown transaction mode (%d)", m)
}
}
// MigrationRecord struct.
//
// Deprecated: unused and will be removed in a future major version.
type MigrationRecord struct {
VersionID int64
TStamp time.Time
IsApplied bool // was this a result of up() or down()
}
// Migration struct represents either a SQL or Go migration.
type Migration struct {
Type MigrationType
Version int64
Source string // path to .sql script or .go file
Registered bool
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
UseTx bool
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
// We still save the non-context versions in the struct in case someone is using them. Goose
// does not use these internally anymore in favor of the context-aware versions.
UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx
noVersioning bool
}
func (m *Migration) String() string {
return fmt.Sprint(m.Source)
}

133
register.go Normal file
View File

@ -0,0 +1,133 @@
package goose
import (
"context"
"database/sql"
"fmt"
"runtime"
)
// GoMigrationContext is a Go migration func that is run within a transaction and receives a
// context.
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error
// AddMigrationContext adds Go migrations.
func AddMigrationContext(up, down GoMigrationContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, up, down)
}
// AddNamedMigrationContext adds named Go migrations.
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
if err := register(
filename,
true,
&GoFunc{RunTx: up},
&GoFunc{RunTx: down},
); err != nil {
panic(err)
}
}
// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a
// context.
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error
// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTxContext(filename, up, down)
}
// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
if err := register(
filename,
false,
&GoFunc{RunDB: up},
&GoFunc{RunDB: down},
); err != nil {
panic(err)
}
}
func register(filename string, useTx bool, up, down *GoFunc) error {
v, _ := NumericComponent(filename)
if existing, ok := registeredGoMigrations[v]; ok {
return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
filename,
v,
existing.Source,
)
}
// Add to global as a registered migration.
m := NewGoMigration(v, up, down)
m.Source = filename
// We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but
// we know based on the register function what the user is requesting.
m.UseTx = useTx
registeredGoMigrations[v] = &m
return nil
}
// withContext changes the signature of a function that receives one argument to receive a context
// and the argument.
func withContext[T any](fn func(T) error) func(context.Context, T) error {
if fn == nil {
return nil
}
return func(ctx context.Context, t T) error {
return fn(t)
}
}
// withoutContext changes the signature of a function that receives a context and one argument to
// receive only the argument. When called the passed context is always context.Background().
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
if fn == nil {
return nil
}
return func(t T) error {
return fn(context.Background(), t)
}
}
// GoMigration is a Go migration func that is run within a transaction.
//
// Deprecated: Use GoMigrationContext.
type GoMigration func(tx *sql.Tx) error
// GoMigrationNoTx is a Go migration func that is run outside a transaction.
//
// Deprecated: Use GoMigrationNoTxContext.
type GoMigrationNoTx func(db *sql.DB) error
// AddMigration adds Go migrations.
//
// Deprecated: Use AddMigrationContext.
func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddNamedMigration adds named Go migrations.
//
// Deprecated: Use AddNamedMigrationContext.
func AddNamedMigration(filename string, up, down GoMigration) {
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddMigrationNoTx adds Go migrations that will be run outside transaction.
//
// Deprecated: Use AddMigrationNoTxContext.
func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
//
// Deprecated: Use AddNamedMigrationNoTxContext.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}

View File

@ -1,27 +1,30 @@
package gomigrations
package gomigrations_test
import (
"database/sql"
"path/filepath"
"testing"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
_ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata"
_ "modernc.org/sqlite"
)
func TestGoMigrationByOne(t *testing.T) {
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
t.Parallel()
check.NoError(t, goose.SetDialect("sqlite3"))
db, err := sql.Open("sqlite", ":memory:")
check.NoError(t, err)
dir := "testdata"
files, err := filepath.Glob(dir + "/*.go")
check.NoError(t, err)
upByOne := func(t *testing.T) int64 {
err = goose.UpByOne(db, dir)
t.Logf("err: %v %s", err, dir)
check.NoError(t, err)
version, err := goose.GetDBVersion(db)
check.NoError(t, err)
@ -42,6 +45,21 @@ func TestGoMigrationByOne(t *testing.T) {
check.NoError(t, err)
check.Number(t, version, len(files))
tables, err := ListTables(db)
check.NoError(t, err)
check.Equal(t, tables, []string{
"alpha",
"bravo",
"charlie",
"delta",
"echo",
"foxtrot",
"golf",
"goose_db_version",
"hotel",
"sqlite_sequence",
})
// Migrate all files down-by-one.
for i := len(files) - 1; i >= 0; i-- {
check.Number(t, downByOne(t), i)
@ -49,4 +67,31 @@ func TestGoMigrationByOne(t *testing.T) {
version, err = goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, 0)
tables, err = ListTables(db)
check.NoError(t, err)
check.Equal(t, tables, []string{
"goose_db_version",
"sqlite_sequence",
})
}
func ListTables(db *sql.DB) ([]string, error) {
rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
tables = append(tables, name)
}
if err := rows.Err(); err != nil {
return nil, err
}
return tables, nil
}

View File

@ -1,9 +1,12 @@
package gomigrations
import (
"context"
"database/sql"
"fmt"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
)
func init() {
@ -11,13 +14,19 @@ func init() {
}
func up001(tx *sql.Tx) error {
q := "CREATE TABLE foo (id INT, subid INT, name TEXT)"
_, err := tx.Exec(q)
return err
return createTable(tx, "alpha")
}
func down001(tx *sql.Tx) error {
q := "DROP TABLE IF EXISTS foo"
_, err := tx.Exec(q)
return dropTable(tx, "alpha")
}
func createTable(db database.DBTxConn, name string) error {
_, err := db.ExecContext(context.Background(), fmt.Sprintf("CREATE TABLE %s (id INTEGER)", name))
return err
}
func dropTable(db database.DBTxConn, name string) error {
_, err := db.ExecContext(context.Background(), fmt.Sprintf("DROP TABLE %s", name))
return err
}

View File

@ -11,7 +11,5 @@ func init() {
}
func up002(tx *sql.Tx) error {
q := "INSERT INTO foo VALUES (1, 1, 'Alice')"
_, err := tx.Exec(q)
return err
return createTable(tx, "bravo")
}

View File

@ -11,7 +11,5 @@ func init() {
}
func down003(tx *sql.Tx) error {
q := "TRUNCATE TABLE foo"
_, err := tx.Exec(q)
return err
return dropTable(tx, "bravo")
}

View File

@ -11,13 +11,9 @@ func init() {
}
func up005(db *sql.DB) error {
q := "CREATE TABLE users (id INT, email TEXT)"
_, err := db.Exec(q)
return err
return createTable(db, "charlie")
}
func down005(db *sql.DB) error {
q := "DROP TABLE IF EXISTS users"
_, err := db.Exec(q)
return err
return dropTable(db, "charlie")
}

View File

@ -11,7 +11,5 @@ func init() {
}
func up006(db *sql.DB) error {
q := "INSERT INTO users VALUES (1, 'admin@example.com')"
_, err := db.Exec(q)
return err
return createTable(db, "delta")
}

View File

@ -11,7 +11,5 @@ func init() {
}
func down007(db *sql.DB) error {
q := "TRUNCATE TABLE users"
_, err := db.Exec(q)
return err
return dropTable(db, "delta")
}

View File

@ -5,5 +5,5 @@ import (
)
func init() {
goose.AddMigration(nil, nil)
goose.AddMigrationNoTx(nil, nil)
}

View File

@ -0,0 +1,20 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(up009, down009)
}
func up009(ctx context.Context, tx *sql.Tx) error {
return createTable(tx, "echo")
}
func down009(ctx context.Context, tx *sql.Tx) error {
return dropTable(tx, "echo")
}

View File

@ -0,0 +1,16 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(up010, nil)
}
func up010(ctx context.Context, tx *sql.Tx) error {
return createTable(tx, "foxtrot")
}

View File

@ -0,0 +1,16 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(nil, down011)
}
func down011(ctx context.Context, tx *sql.Tx) error {
return dropTable(tx, "foxtrot")
}

View File

@ -0,0 +1,9 @@
package gomigrations
import (
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationContext(nil, nil)
}

View File

@ -0,0 +1,20 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTxContext(up013, down013)
}
func up013(ctx context.Context, db *sql.DB) error {
return createTable(db, "golf")
}
func down013(ctx context.Context, db *sql.DB) error {
return dropTable(db, "golf")
}

View File

@ -0,0 +1,16 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTxContext(up014, nil)
}
func up014(ctx context.Context, db *sql.DB) error {
return createTable(db, "hotel")
}

View File

@ -0,0 +1,16 @@
package gomigrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTxContext(nil, down015)
}
func down015(ctx context.Context, db *sql.DB) error {
return dropTable(db, "hotel")
}

View File

@ -0,0 +1,9 @@
package gomigrations
import (
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTxContext(nil, nil)
}