feat: expose new functions for setting context (#517)

pull/544/head
Ori Shalom 2023-06-16 16:34:33 +03:00 committed by GitHub
parent 949787e4da
commit 7dcddde25a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 128 additions and 40 deletions

29
down.go
View File

@ -1,12 +1,19 @@
package goose
import (
"context"
"database/sql"
"fmt"
)
// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return DownContext(ctx, db, dir, opts...)
}
// DownContext rolls back a single migration from the current version.
func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -21,9 +28,9 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1)
return downToNoVersioning(ctx, db, migrations, currentVersion-1)
}
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
@ -31,11 +38,17 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return current.Down(db)
return current.DownContext(ctx, db)
}
// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background()
return DownToContext(ctx, db, dir, version, opts...)
}
// DownToContext rolls back migrations to a specific version.
func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err
}
if option.noVersioning {
return downToNoVersioning(db, migrations, version)
return downToNoVersioning(ctx, db, migrations, version)
}
for {
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
@ -69,7 +82,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil
}
if err = current.Down(db); err != nil {
if err = current.DownContext(ctx, db); err != nil {
return err
}
}
@ -77,7 +90,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// downToNoVersioning applies down migrations down to, but not including, the
// target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version {
@ -85,7 +98,7 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error
break
}
migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil {
if err := migrations[i].DownContext(ctx, db); err != nil {
return err
}
}

View File

@ -1,6 +1,7 @@
package goose
import (
"context"
"database/sql"
"fmt"
"io/fs"
@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) {
// Run runs a goose command.
func Run(command string, db *sql.DB, dir string, args ...string) error {
return run(command, db, dir, args)
ctx := context.Background()
return RunContext(ctx, command, db, dir, args...)
}
// Run runs a goose command with options.
// RunContext runs a goose command.
func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error {
return run(ctx, command, db, dir, args)
}
// RunWithOptions runs a goose command with options.
func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(command, db, dir, args, options...)
ctx := context.Background()
return RunWithOptionsContext(ctx, command, db, dir, args, options...)
}
func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
// RunWithOptionsContext runs a goose command with options.
func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(ctx, command, db, dir, args, options...)
}
func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
switch command {
case "up":
if err := Up(db, dir, options...); err != nil {
if err := UpContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-by-one":
if err := UpByOne(db, dir, options...); err != nil {
if err := UpByOneContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-to":
@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := UpTo(db, dir, version, options...); err != nil {
if err := UpToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "create":
@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err
}
case "down":
if err := Down(db, dir, options...); err != nil {
if err := DownContext(ctx, db, dir, options...); err != nil {
return err
}
case "down-to":
@ -94,7 +107,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := DownTo(db, dir, version, options...); err != nil {
if err := DownToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "fix":
@ -102,19 +115,19 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err
}
case "redo":
if err := Redo(db, dir, options...); err != nil {
if err := RedoContext(ctx, db, dir, options...); err != nil {
return err
}
case "reset":
if err := Reset(db, dir, options...); err != nil {
if err := ResetContext(ctx, db, dir, options...); err != nil {
return err
}
case "status":
if err := Status(db, dir, options...); err != nil {
if err := StatusContext(ctx, db, dir, options...); err != nil {
return err
}
case "version":
if err := Version(db, dir, options...); err != nil {
if err := VersionContext(ctx, db, dir, options...); err != nil {
return err
}
default:

View File

@ -296,6 +296,12 @@ func versionFilter(v, current, target int64) bool {
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
ctx := context.Background()
return EnsureDBVersionContext(ctx, db)
}
// EnsureDBVersionContext retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil {
return 0, createVersionTable(ctx, db)
@ -332,7 +338,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
// createVersionTable creates the db version table and inserts the
// initial 0 value into it.
func createVersionTable(ctx context.Context, db *sql.DB) error {
txn, err := db.Begin()
txn, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -349,7 +355,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {
// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db)
ctx := context.Background()
return GetDBVersionContext(ctx, db)
}
// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
version, err := EnsureDBVersionContext(ctx, db)
if err != nil {
return -1, err
}

View File

@ -40,6 +40,11 @@ func (m *Migration) String() string {
// Up runs an up migration.
func (m *Migration) Up(db *sql.DB) error {
ctx := context.Background()
return m.UpContext(ctx, db)
}
// UpContext runs an up migration.
func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil {
return err
}
@ -49,6 +54,11 @@ func (m *Migration) Up(db *sql.DB) error {
// Down runs a down migration.
func (m *Migration) Down(db *sql.DB) error {
ctx := context.Background()
return m.DownContext(ctx, db)
}
// DownContext runs a down migration.
func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil {
return err
}
@ -163,7 +173,7 @@ func runGoMigration(
if fn == nil && !recordVersion {
return nil
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

View File

@ -29,7 +29,7 @@ func runSQLMigration(
verboseInfo("Begin transaction")
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

13
redo.go
View File

@ -1,11 +1,18 @@
package goose
import (
"context"
"database/sql"
)
// Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return RedoContext(ctx, db, dir, opts...)
}
// RedoContext rolls back the most recently applied migration, then runs it again.
func RedoContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -23,7 +30,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion = migrations[len(migrations)-1].Version
} else {
if currentVersion, err = GetDBVersion(db); err != nil {
if currentVersion, err = GetDBVersionContext(ctx, db); err != nil {
return err
}
}
@ -34,10 +41,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
current.noVersioning = option.noVersioning
if err := current.Down(db); err != nil {
if err := current.DownContext(ctx, db); err != nil {
return err
}
if err := current.Up(db); err != nil {
if err := current.UpContext(ctx, db); err != nil {
return err
}
return nil

View File

@ -10,6 +10,11 @@ import (
// Reset rolls back all migrations
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return ResetContext(ctx, db, dir, opts...)
}
// ResetContext rolls back all migrations
func ResetContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -19,7 +24,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
return fmt.Errorf("failed to collect migrations: %w", err)
}
if option.noVersioning {
return DownTo(db, dir, minVersion, opts...)
return DownToContext(ctx, db, dir, minVersion, opts...)
}
statuses, err := dbMigrationsStatus(ctx, db)
@ -32,7 +37,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
if !statuses[migration.Version] {
continue
}
if err = migration.Down(db); err != nil {
if err = migration.DownContext(ctx, db); err != nil {
return fmt.Errorf("failed to db-down: %w", err)
}
}

View File

@ -12,6 +12,11 @@ import (
// Status prints the status of all migrations.
func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return StatusContext(ctx, db, dir, opts...)
}
// StatusContext prints the status of all migrations.
func StatusContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -30,7 +35,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
if _, err := EnsureDBVersionContext(ctx, db); err != nil {
return fmt.Errorf("failed to ensure DB version: %w", err)
}

32
up.go
View File

@ -35,6 +35,10 @@ func withApplyUpByOne() OptionsFunc {
// UpTo migrates up to a specific version.
func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background()
return UpToContext(ctx, db, dir, version, opts...)
}
func UpToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -53,10 +57,10 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// migration over and over.
version = foundMigrations[0].Version
}
return upToNoVersioning(db, foundMigrations, version)
return upToNoVersioning(ctx, db, foundMigrations, version)
}
if _, err := EnsureDBVersion(db); err != nil {
if _, err := EnsureDBVersionContext(ctx, db); err != nil {
return err
}
dbMigrations, err := listAllDBVersions(ctx, db)
@ -103,7 +107,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
var current int64
for _, m := range migrationsToApply {
if err := m.Up(db); err != nil {
if err := m.UpContext(ctx, db); err != nil {
return err
}
if option.applyUpByOne {
@ -112,7 +116,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
current = m.Version
}
if len(migrationsToApply) == 0 {
current, err = GetDBVersion(db)
current, err = GetDBVersionContext(ctx, db)
if err != nil {
return err
}
@ -130,14 +134,14 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// upToNoVersioning applies up migrations up to, and including, the
// target version.
func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func upToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for _, current := range migrations {
if current.Version > version {
break
}
current.noVersioning = true
if err := current.Up(db); err != nil {
if err := current.UpContext(ctx, db); err != nil {
return err
}
finalVersion = current.Version
@ -148,13 +152,25 @@ func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
// Up applies all available migrations.
func Up(db *sql.DB, dir string, opts ...OptionsFunc) error {
return UpTo(db, dir, maxVersion, opts...)
ctx := context.Background()
return UpContext(ctx, db, dir, opts...)
}
// UpContext applies all available migrations.
func UpContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
return UpToContext(ctx, db, dir, maxVersion, opts...)
}
// UpByOne migrates up by a single version.
func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return UpByOneContext(ctx, db, dir, opts...)
}
// UpByOneContext migrates up by a single version.
func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
opts = append(opts, withApplyUpByOne())
return UpTo(db, dir, maxVersion, opts...)
return UpToContext(ctx, db, dir, maxVersion, opts...)
}
// listAllDBVersions returns a list of all migrations, ordered ascending.

View File

@ -1,12 +1,19 @@
package goose
import (
"context"
"database/sql"
"fmt"
)
// Version prints the current version of the database.
func Version(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return VersionContext(ctx, db, dir, opts...)
}
// VersionContext prints the current version of the database.
func VersionContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
@ -24,7 +31,7 @@ func Version(db *sql.DB, dir string, opts ...OptionsFunc) error {
return nil
}
current, err := GetDBVersion(db)
current, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}