feat: expose new functions for setting context (#517)

pull/544/head
Ori Shalom 2023-06-16 16:34:33 +03:00 committed by GitHub
parent 949787e4da
commit 7dcddde25a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 128 additions and 40 deletions

29
down.go
View File

@ -1,12 +1,19 @@
package goose package goose
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
) )
// Down rolls back a single migration from the current version. // Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return DownContext(ctx, db, dir, opts...)
}
// DownContext rolls back a single migration from the current version.
func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -21,9 +28,9 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
currentVersion := migrations[len(migrations)-1].Version currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down. // Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1) return downToNoVersioning(ctx, db, migrations, currentVersion-1)
} }
currentVersion, err := GetDBVersion(db) currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil { if err != nil {
return err return err
} }
@ -31,11 +38,17 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
if err != nil { if err != nil {
return fmt.Errorf("no migration %v", currentVersion) return fmt.Errorf("no migration %v", currentVersion)
} }
return current.Down(db) return current.DownContext(ctx, db)
} }
// DownTo rolls back migrations to a specific version. // DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background()
return DownToContext(ctx, db, dir, version, opts...)
}
// DownToContext rolls back migrations to a specific version.
func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err return err
} }
if option.noVersioning { if option.noVersioning {
return downToNoVersioning(db, migrations, version) return downToNoVersioning(ctx, db, migrations, version)
} }
for { for {
currentVersion, err := GetDBVersion(db) currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil { if err != nil {
return err return err
} }
@ -69,7 +82,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil return nil
} }
if err = current.Down(db); err != nil { if err = current.DownContext(ctx, db); err != nil {
return err return err
} }
} }
@ -77,7 +90,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// downToNoVersioning applies down migrations down to, but not including, the // downToNoVersioning applies down migrations down to, but not including, the
// target version. // target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64 var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- { for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version { if version >= migrations[i].Version {
@ -85,7 +98,7 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error
break break
} }
migrations[i].noVersioning = true migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil { if err := migrations[i].DownContext(ctx, db); err != nil {
return err return err
} }
} }

View File

@ -1,6 +1,7 @@
package goose package goose
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io/fs" "io/fs"
@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) {
// Run runs a goose command. // Run runs a goose command.
func Run(command string, db *sql.DB, dir string, args ...string) error { func Run(command string, db *sql.DB, dir string, args ...string) error {
return run(command, db, dir, args) ctx := context.Background()
return RunContext(ctx, command, db, dir, args...)
} }
// Run runs a goose command with options. // RunContext runs a goose command.
func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error {
return run(ctx, command, db, dir, args)
}
// RunWithOptions runs a goose command with options.
func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(command, db, dir, args, options...) ctx := context.Background()
return RunWithOptionsContext(ctx, command, db, dir, args, options...)
} }
func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { // RunWithOptionsContext runs a goose command with options.
func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(ctx, command, db, dir, args, options...)
}
func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
switch command { switch command {
case "up": case "up":
if err := Up(db, dir, options...); err != nil { if err := UpContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "up-by-one": case "up-by-one":
if err := UpByOne(db, dir, options...); err != nil { if err := UpByOneContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "up-to": case "up-to":
@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil { if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0]) return fmt.Errorf("version must be a number (got '%s')", args[0])
} }
if err := UpTo(db, dir, version, options...); err != nil { if err := UpToContext(ctx, db, dir, version, options...); err != nil {
return err return err
} }
case "create": case "create":
@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err return err
} }
case "down": case "down":
if err := Down(db, dir, options...); err != nil { if err := DownContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "down-to": case "down-to":
@ -94,7 +107,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil { if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0]) return fmt.Errorf("version must be a number (got '%s')", args[0])
} }
if err := DownTo(db, dir, version, options...); err != nil { if err := DownToContext(ctx, db, dir, version, options...); err != nil {
return err return err
} }
case "fix": case "fix":
@ -102,19 +115,19 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err return err
} }
case "redo": case "redo":
if err := Redo(db, dir, options...); err != nil { if err := RedoContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "reset": case "reset":
if err := Reset(db, dir, options...); err != nil { if err := ResetContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "status": case "status":
if err := Status(db, dir, options...); err != nil { if err := StatusContext(ctx, db, dir, options...); err != nil {
return err return err
} }
case "version": case "version":
if err := Version(db, dir, options...); err != nil { if err := VersionContext(ctx, db, dir, options...); err != nil {
return err return err
} }
default: default:

View File

@ -296,6 +296,12 @@ func versionFilter(v, current, target int64) bool {
// Create and initialize the DB version table if it doesn't exist. // Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) { func EnsureDBVersion(db *sql.DB) (int64, error) {
ctx := context.Background() ctx := context.Background()
return EnsureDBVersionContext(ctx, db)
}
// EnsureDBVersionContext retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName()) dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil { if err != nil {
return 0, createVersionTable(ctx, db) return 0, createVersionTable(ctx, db)
@ -332,7 +338,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
// createVersionTable creates the db version table and inserts the // createVersionTable creates the db version table and inserts the
// initial 0 value into it. // initial 0 value into it.
func createVersionTable(ctx context.Context, db *sql.DB) error { func createVersionTable(ctx context.Context, db *sql.DB) error {
txn, err := db.Begin() txn, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -349,7 +355,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {
// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error. // GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersion(db *sql.DB) (int64, error) { func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db) ctx := context.Background()
return GetDBVersionContext(ctx, db)
}
// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
version, err := EnsureDBVersionContext(ctx, db)
if err != nil { if err != nil {
return -1, err return -1, err
} }

View File

@ -40,6 +40,11 @@ func (m *Migration) String() string {
// Up runs an up migration. // Up runs an up migration.
func (m *Migration) Up(db *sql.DB) error { func (m *Migration) Up(db *sql.DB) error {
ctx := context.Background() ctx := context.Background()
return m.UpContext(ctx, db)
}
// UpContext runs an up migration.
func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil { if err := m.run(ctx, db, true); err != nil {
return err return err
} }
@ -49,6 +54,11 @@ func (m *Migration) Up(db *sql.DB) error {
// Down runs a down migration. // Down runs a down migration.
func (m *Migration) Down(db *sql.DB) error { func (m *Migration) Down(db *sql.DB) error {
ctx := context.Background() ctx := context.Background()
return m.DownContext(ctx, db)
}
// DownContext runs a down migration.
func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil { if err := m.run(ctx, db, false); err != nil {
return err return err
} }
@ -163,7 +173,7 @@ func runGoMigration(
if fn == nil && !recordVersion { if fn == nil && !recordVersion {
return nil return nil
} }
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
} }

View File

@ -29,7 +29,7 @@ func runSQLMigration(
verboseInfo("Begin transaction") verboseInfo("Begin transaction")
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
} }

13
redo.go
View File

@ -1,11 +1,18 @@
package goose package goose
import ( import (
"context"
"database/sql" "database/sql"
) )
// Redo rolls back the most recently applied migration, then runs it again. // Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return RedoContext(ctx, db, dir, opts...)
}
// RedoContext rolls back the most recently applied migration, then runs it again.
func RedoContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -23,7 +30,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
currentVersion = migrations[len(migrations)-1].Version currentVersion = migrations[len(migrations)-1].Version
} else { } else {
if currentVersion, err = GetDBVersion(db); err != nil { if currentVersion, err = GetDBVersionContext(ctx, db); err != nil {
return err return err
} }
} }
@ -34,10 +41,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
current.noVersioning = option.noVersioning current.noVersioning = option.noVersioning
if err := current.Down(db); err != nil { if err := current.DownContext(ctx, db); err != nil {
return err return err
} }
if err := current.Up(db); err != nil { if err := current.UpContext(ctx, db); err != nil {
return err return err
} }
return nil return nil

View File

@ -10,6 +10,11 @@ import (
// Reset rolls back all migrations // Reset rolls back all migrations
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background() ctx := context.Background()
return ResetContext(ctx, db, dir, opts...)
}
// ResetContext rolls back all migrations
func ResetContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -19,7 +24,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
return fmt.Errorf("failed to collect migrations: %w", err) return fmt.Errorf("failed to collect migrations: %w", err)
} }
if option.noVersioning { if option.noVersioning {
return DownTo(db, dir, minVersion, opts...) return DownToContext(ctx, db, dir, minVersion, opts...)
} }
statuses, err := dbMigrationsStatus(ctx, db) statuses, err := dbMigrationsStatus(ctx, db)
@ -32,7 +37,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
if !statuses[migration.Version] { if !statuses[migration.Version] {
continue continue
} }
if err = migration.Down(db); err != nil { if err = migration.DownContext(ctx, db); err != nil {
return fmt.Errorf("failed to db-down: %w", err) return fmt.Errorf("failed to db-down: %w", err)
} }
} }

View File

@ -12,6 +12,11 @@ import (
// Status prints the status of all migrations. // Status prints the status of all migrations.
func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background() ctx := context.Background()
return StatusContext(ctx, db, dir, opts...)
}
// StatusContext prints the status of all migrations.
func StatusContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -30,7 +35,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
// must ensure that the version table exists if we're running on a pristine DB // must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil { if _, err := EnsureDBVersionContext(ctx, db); err != nil {
return fmt.Errorf("failed to ensure DB version: %w", err) return fmt.Errorf("failed to ensure DB version: %w", err)
} }

32
up.go
View File

@ -35,6 +35,10 @@ func withApplyUpByOne() OptionsFunc {
// UpTo migrates up to a specific version. // UpTo migrates up to a specific version.
func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background() ctx := context.Background()
return UpToContext(ctx, db, dir, version, opts...)
}
func UpToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -53,10 +57,10 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// migration over and over. // migration over and over.
version = foundMigrations[0].Version version = foundMigrations[0].Version
} }
return upToNoVersioning(db, foundMigrations, version) return upToNoVersioning(ctx, db, foundMigrations, version)
} }
if _, err := EnsureDBVersion(db); err != nil { if _, err := EnsureDBVersionContext(ctx, db); err != nil {
return err return err
} }
dbMigrations, err := listAllDBVersions(ctx, db) dbMigrations, err := listAllDBVersions(ctx, db)
@ -103,7 +107,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
var current int64 var current int64
for _, m := range migrationsToApply { for _, m := range migrationsToApply {
if err := m.Up(db); err != nil { if err := m.UpContext(ctx, db); err != nil {
return err return err
} }
if option.applyUpByOne { if option.applyUpByOne {
@ -112,7 +116,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
current = m.Version current = m.Version
} }
if len(migrationsToApply) == 0 { if len(migrationsToApply) == 0 {
current, err = GetDBVersion(db) current, err = GetDBVersionContext(ctx, db)
if err != nil { if err != nil {
return err return err
} }
@ -130,14 +134,14 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// upToNoVersioning applies up migrations up to, and including, the // upToNoVersioning applies up migrations up to, and including, the
// target version. // target version.
func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { func upToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64 var finalVersion int64
for _, current := range migrations { for _, current := range migrations {
if current.Version > version { if current.Version > version {
break break
} }
current.noVersioning = true current.noVersioning = true
if err := current.Up(db); err != nil { if err := current.UpContext(ctx, db); err != nil {
return err return err
} }
finalVersion = current.Version finalVersion = current.Version
@ -148,13 +152,25 @@ func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
// Up applies all available migrations. // Up applies all available migrations.
func Up(db *sql.DB, dir string, opts ...OptionsFunc) error { func Up(db *sql.DB, dir string, opts ...OptionsFunc) error {
return UpTo(db, dir, maxVersion, opts...) ctx := context.Background()
return UpContext(ctx, db, dir, opts...)
}
// UpContext applies all available migrations.
func UpContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
return UpToContext(ctx, db, dir, maxVersion, opts...)
} }
// UpByOne migrates up by a single version. // UpByOne migrates up by a single version.
func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return UpByOneContext(ctx, db, dir, opts...)
}
// UpByOneContext migrates up by a single version.
func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
opts = append(opts, withApplyUpByOne()) opts = append(opts, withApplyUpByOne())
return UpTo(db, dir, maxVersion, opts...) return UpToContext(ctx, db, dir, maxVersion, opts...)
} }
// listAllDBVersions returns a list of all migrations, ordered ascending. // listAllDBVersions returns a list of all migrations, ordered ascending.

View File

@ -1,12 +1,19 @@
package goose package goose
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
) )
// Version prints the current version of the database. // Version prints the current version of the database.
func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { func Version(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return VersionContext(ctx, db, dir, opts...)
}
// VersionContext prints the current version of the database.
func VersionContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{} option := &options{}
for _, f := range opts { for _, f := range opts {
f(option) f(option)
@ -24,7 +31,7 @@ func Version(db *sql.DB, dir string, opts ...OptionsFunc) error {
return nil return nil
} }
current, err := GetDBVersion(db) current, err := GetDBVersionContext(ctx, db)
if err != nil { if err != nil {
return err return err
} }