feat: add context-aware Go migrations (#534)

pull/530/merge
Ori Shalom 2023-06-29 23:15:39 +03:00 committed by GitHub
parent 7d9fbafd99
commit e18fac6930
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 100 additions and 33 deletions

View File

@ -99,20 +99,21 @@ SELECT 'down SQL query';
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations
import ( import (
"context"
"database/sql" "database/sql"
"github.com/pressly/goose/v3" "github.com/pressly/goose/v3"
) )
func init() { func init() {
goose.AddMigration(up{{.CamelName}}, down{{.CamelName}}) goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}})
} }
func up{{.CamelName}}(tx *sql.Tx) error { func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error {
// This code is executed when the migration is applied. // This code is executed when the migration is applied.
return nil return nil
} }
func down{{.CamelName}}(tx *sql.Tx) error { func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error {
// This code is executed when the migration is rolled back. // This code is executed when the migration is rolled back.
return nil return nil
} }

View File

@ -130,17 +130,35 @@ func (ms Migrations) String() string {
// GoMigration is a Go migration func that is run within a transaction. // GoMigration is a Go migration func that is run within a transaction.
type GoMigration func(tx *sql.Tx) error type GoMigration func(tx *sql.Tx) error
// GoMigrationContext is a Go migration func that is run within a transaction and receives a context.
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error
// GoMigrationNoTx is a Go migration func that is run outside a transaction. // GoMigrationNoTx is a Go migration func that is run outside a transaction.
type GoMigrationNoTx func(db *sql.DB) error type GoMigrationNoTx func(db *sql.DB) error
// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context.
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error
// AddMigration adds Go migrations. // AddMigration adds Go migrations.
func AddMigration(up, down GoMigration) { func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1) _, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down) // intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddMigrationContext adds Go migrations.
func AddMigrationContext(up, down GoMigrationContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, up, down)
} }
// AddNamedMigration adds named Go migrations. // AddNamedMigration adds named Go migrations.
func AddNamedMigration(filename string, up, down GoMigration) { func AddNamedMigration(filename string, up, down GoMigration) {
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}
// AddNamedMigrationContext adds named Go migrations.
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
if err := register(filename, true, up, down, nil, nil); err != nil { if err := register(filename, true, up, down, nil, nil); err != nil {
panic(err) panic(err)
} }
@ -148,12 +166,22 @@ func AddNamedMigration(filename string, up, down GoMigration) {
// AddMigrationNoTx adds Go migrations that will be run outside transaction. // AddMigrationNoTx adds Go migrations that will be run outside transaction.
func AddMigrationNoTx(up, down GoMigrationNoTx) { func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1) AddMigrationNoTxContext(withContext(up), withContext(down))
AddNamedMigrationNoTx(filename, up, down) }
// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
_, filename, _, _ := runtime.Caller(2)
AddNamedMigrationNoTxContext(filename, up, down)
} }
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. // AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}
// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
if err := register(filename, false, nil, nil, up, down); err != nil { if err := register(filename, false, nil, nil, up, down); err != nil {
panic(err) panic(err)
} }
@ -162,8 +190,8 @@ func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
func register( func register(
filename string, filename string,
useTx bool, useTx bool,
up, down GoMigration, up, down GoMigrationContext,
upNoTx, downNoTx GoMigrationNoTx, upNoTx, downNoTx GoMigrationNoTxContext,
) error { ) error {
// Sanity check caller did not mix tx and non-tx based functions. // Sanity check caller did not mix tx and non-tx based functions.
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
@ -179,16 +207,23 @@ func register(
} }
// Add to global as a registered migration. // Add to global as a registered migration.
registeredGoMigrations[v] = &Migration{ registeredGoMigrations[v] = &Migration{
Version: v, Version: v,
Next: -1, Next: -1,
Previous: -1, Previous: -1,
Registered: true, Registered: true,
Source: filename, Source: filename,
UseTx: useTx, UseTx: useTx,
UpFn: up, UpFnContext: up,
DownFn: down, DownFnContext: down,
UpFnNoTx: upNoTx, UpFnNoTxContext: upNoTx,
DownFnNoTx: downNoTx, DownFnNoTxContext: downNoTx,
// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn: withoutContext(up),
DownFn: withoutContext(down),
UpFnNoTx: withoutContext(upNoTx),
DownFnNoTx: withoutContext(downNoTx),
} }
return nil return nil
} }
@ -378,3 +413,26 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
return version, nil return version, nil
} }
// withContext changes the signature of a function that receives one argument to receive a context and the argument.
func withContext[T any](fn func(T) error) func(context.Context, T) error {
if fn == nil {
return nil
}
return func(ctx context.Context, t T) error {
return fn(t)
}
}
// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument.
// When called the passed context is always context.Background().
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
if fn == nil {
return nil
}
return func(t T) error {
return fn(context.Background(), t)
}
}

View File

@ -22,15 +22,23 @@ type MigrationRecord struct {
// Migration struct. // Migration struct.
type Migration struct { type Migration struct {
Version int64 Version int64
Next int64 // next version, or -1 if none Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file Source string // path to .sql script or go file
Registered bool Registered bool
UseTx bool UseTx bool
// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn, DownFn GoMigration UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx UpFnNoTx, DownFnNoTx GoMigrationNoTx
noVersioning bool
// New functions with context
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
noVersioning bool
} }
func (m *Migration) String() string { func (m *Migration) String() string {
@ -99,9 +107,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
var empty bool var empty bool
if m.UseTx { if m.UseTx {
// Run go-based migration inside a tx. // Run go-based migration inside a tx.
fn := m.DownFn fn := m.DownFnContext
if direction { if direction {
fn = m.UpFn fn = m.UpFnContext
} }
empty = (fn == nil) empty = (fn == nil)
if err := runGoMigration( if err := runGoMigration(
@ -116,9 +124,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
} }
} else { } else {
// Run go-based migration outside a tx. // Run go-based migration outside a tx.
fn := m.DownFnNoTx fn := m.DownFnNoTxContext
if direction { if direction {
fn = m.UpFnNoTx fn = m.UpFnNoTxContext
} }
empty = (fn == nil) empty = (fn == nil)
if err := runGoMigrationNoTx( if err := runGoMigrationNoTx(
@ -145,14 +153,14 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
func runGoMigrationNoTx( func runGoMigrationNoTx(
ctx context.Context, ctx context.Context,
db *sql.DB, db *sql.DB,
fn GoMigrationNoTx, fn GoMigrationNoTxContext,
version int64, version int64,
direction bool, direction bool,
recordVersion bool, recordVersion bool,
) error { ) error {
if fn != nil { if fn != nil {
// Run go migration function. // Run go migration function.
if err := fn(db); err != nil { if err := fn(ctx, db); err != nil {
return fmt.Errorf("failed to run go migration: %w", err) return fmt.Errorf("failed to run go migration: %w", err)
} }
} }
@ -165,7 +173,7 @@ func runGoMigrationNoTx(
func runGoMigration( func runGoMigration(
ctx context.Context, ctx context.Context,
db *sql.DB, db *sql.DB,
fn GoMigration, fn GoMigrationContext,
version int64, version int64,
direction bool, direction bool,
recordVersion bool, recordVersion bool,
@ -179,7 +187,7 @@ func runGoMigration(
} }
if fn != nil { if fn != nil {
// Run go migration function. // Run go migration function.
if err := fn(tx); err != nil { if err := fn(ctx, tx); err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return fmt.Errorf("failed to run go migration: %w", err) return fmt.Errorf("failed to run go migration: %w", err)
} }