From e18fac69304ff175fda05d9b4185d45a94dee356 Mon Sep 17 00:00:00 2001 From: Ori Shalom Date: Thu, 29 Jun 2023 23:15:39 +0300 Subject: [PATCH] feat: add context-aware Go migrations (#534) --- create.go | 7 +++-- migrate.go | 88 +++++++++++++++++++++++++++++++++++++++++++--------- migration.go | 38 ++++++++++++++--------- 3 files changed, 100 insertions(+), 33 deletions(-) diff --git a/create.go b/create.go index 8634763..d9ec002 100644 --- a/create.go +++ b/create.go @@ -99,20 +99,21 @@ SELECT 'down SQL query'; var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations import ( + "context" "database/sql" "github.com/pressly/goose/v3" ) func init() { - goose.AddMigration(up{{.CamelName}}, down{{.CamelName}}) + goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(tx *sql.Tx) error { +func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { // This code is executed when the migration is applied. return nil } -func down{{.CamelName}}(tx *sql.Tx) error { +func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { // This code is executed when the migration is rolled back. return nil } diff --git a/migrate.go b/migrate.go index 96c2ccc..01b4d9c 100644 --- a/migrate.go +++ b/migrate.go @@ -130,17 +130,35 @@ func (ms Migrations) String() string { // GoMigration is a Go migration func that is run within a transaction. type GoMigration func(tx *sql.Tx) error +// GoMigrationContext is a Go migration func that is run within a transaction and receives a context. +type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error + // GoMigrationNoTx is a Go migration func that is run outside a transaction. type GoMigrationNoTx func(db *sql.DB) error +// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context. +type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error + // AddMigration adds Go migrations. func AddMigration(up, down GoMigration) { _, filename, _, _ := runtime.Caller(1) - AddNamedMigration(filename, up, down) + // intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddMigrationContext adds Go migrations. +func AddMigrationContext(up, down GoMigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) } // AddNamedMigration adds named Go migrations. func AddNamedMigration(filename string, up, down GoMigration) { + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { if err := register(filename, true, up, down, nil, nil); err != nil { panic(err) } @@ -148,12 +166,22 @@ func AddNamedMigration(filename string, up, down GoMigration) { // AddMigrationNoTx adds Go migrations that will be run outside transaction. func AddMigrationNoTx(up, down GoMigrationNoTx) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTx(filename, up, down) + AddMigrationNoTxContext(withContext(up), withContext(down)) +} + +// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. +func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { + _, filename, _, _ := runtime.Caller(2) + AddNamedMigrationNoTxContext(filename, up, down) } // AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { + AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. +func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { if err := register(filename, false, nil, nil, up, down); err != nil { panic(err) } @@ -162,8 +190,8 @@ func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { func register( filename string, useTx bool, - up, down GoMigration, - upNoTx, downNoTx GoMigrationNoTx, + up, down GoMigrationContext, + upNoTx, downNoTx GoMigrationNoTxContext, ) error { // Sanity check caller did not mix tx and non-tx based functions. if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { @@ -179,16 +207,23 @@ func register( } // Add to global as a registered migration. registeredGoMigrations[v] = &Migration{ - Version: v, - Next: -1, - Previous: -1, - Registered: true, - Source: filename, - UseTx: useTx, - UpFn: up, - DownFn: down, - UpFnNoTx: upNoTx, - DownFnNoTx: downNoTx, + Version: v, + Next: -1, + Previous: -1, + Registered: true, + Source: filename, + UseTx: useTx, + UpFnContext: up, + DownFnContext: down, + UpFnNoTxContext: upNoTx, + DownFnNoTxContext: downNoTx, + // These are deprecated and will be removed in the future. + // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. + // Goose does not use these internally anymore and instead uses the context versions. + UpFn: withoutContext(up), + DownFn: withoutContext(down), + UpFnNoTx: withoutContext(upNoTx), + DownFnNoTx: withoutContext(downNoTx), } return nil } @@ -378,3 +413,26 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { return version, nil } + +// withContext changes the signature of a function that receives one argument to receive a context and the argument. +func withContext[T any](fn func(T) error) func(context.Context, T) error { + if fn == nil { + return nil + } + + return func(ctx context.Context, t T) error { + return fn(t) + } +} + +// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument. +// When called the passed context is always context.Background(). +func withoutContext[T any](fn func(context.Context, T) error) func(T) error { + if fn == nil { + return nil + } + + return func(t T) error { + return fn(context.Background(), t) + } +} diff --git a/migration.go b/migration.go index 8d4362c..dcf0c61 100644 --- a/migration.go +++ b/migration.go @@ -22,15 +22,23 @@ type MigrationRecord struct { // Migration struct. type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script or go file - Registered bool - UseTx bool + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script or go file + Registered bool + UseTx bool + + // These are deprecated and will be removed in the future. + // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. + // Goose does not use these internally anymore and instead uses the context versions. UpFn, DownFn GoMigration UpFnNoTx, DownFnNoTx GoMigrationNoTx - noVersioning bool + + // New functions with context + UpFnContext, DownFnContext GoMigrationContext + UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext + noVersioning bool } func (m *Migration) String() string { @@ -99,9 +107,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { var empty bool if m.UseTx { // Run go-based migration inside a tx. - fn := m.DownFn + fn := m.DownFnContext if direction { - fn = m.UpFn + fn = m.UpFnContext } empty = (fn == nil) if err := runGoMigration( @@ -116,9 +124,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { } } else { // Run go-based migration outside a tx. - fn := m.DownFnNoTx + fn := m.DownFnNoTxContext if direction { - fn = m.UpFnNoTx + fn = m.UpFnNoTxContext } empty = (fn == nil) if err := runGoMigrationNoTx( @@ -145,14 +153,14 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { func runGoMigrationNoTx( ctx context.Context, db *sql.DB, - fn GoMigrationNoTx, + fn GoMigrationNoTxContext, version int64, direction bool, recordVersion bool, ) error { if fn != nil { // Run go migration function. - if err := fn(db); err != nil { + if err := fn(ctx, db); err != nil { return fmt.Errorf("failed to run go migration: %w", err) } } @@ -165,7 +173,7 @@ func runGoMigrationNoTx( func runGoMigration( ctx context.Context, db *sql.DB, - fn GoMigration, + fn GoMigrationContext, version int64, direction bool, recordVersion bool, @@ -179,7 +187,7 @@ func runGoMigration( } if fn != nil { // Run go migration function. - if err := fn(tx); err != nil { + if err := fn(ctx, tx); err != nil { _ = tx.Rollback() return fmt.Errorf("failed to run go migration: %w", err) }