mirror of https://github.com/pressly/goose.git
399 lines
11 KiB
Go
399 lines
11 KiB
Go
package goose
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pressly/goose/v3/internal/sqlparser"
|
|
)
|
|
|
|
// NewGoMigration creates a new Go migration.
|
|
//
|
|
// Both up and down functions may be nil, in which case the migration will be recorded in the
|
|
// versions table but no functions will be run. This is useful for recording (up) or deleting (down)
|
|
// a version without running any functions. See [GoFunc] for more details.
|
|
func NewGoMigration(version int64, up, down *GoFunc) *Migration {
|
|
m := &Migration{
|
|
Type: TypeGo,
|
|
Registered: true,
|
|
Version: version,
|
|
Next: -1, Previous: -1,
|
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
|
construct: true,
|
|
}
|
|
updateMode := func(f *GoFunc) *GoFunc {
|
|
// infer mode from function
|
|
if f.Mode == 0 {
|
|
if f.RunTx != nil && f.RunDB == nil {
|
|
f.Mode = TransactionEnabled
|
|
}
|
|
if f.RunTx == nil && f.RunDB != nil {
|
|
f.Mode = TransactionDisabled
|
|
}
|
|
// Always default to TransactionEnabled if both functions are nil. This is the most
|
|
// common use case.
|
|
if f.RunDB == nil && f.RunTx == nil {
|
|
f.Mode = TransactionEnabled
|
|
}
|
|
}
|
|
return f
|
|
}
|
|
// To maintain backwards compatibility, we set ALL legacy functions. In a future major version,
|
|
// we will remove these fields in favor of [GoFunc].
|
|
//
|
|
// Note, this function does not do any validation. Validation is lazily done when the migration
|
|
// is registered.
|
|
if up != nil {
|
|
m.goUp = updateMode(up)
|
|
|
|
if up.RunDB != nil {
|
|
m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error
|
|
m.UpFnNoTx = withoutContext(up.RunDB) // func(*sql.DB) error
|
|
}
|
|
if up.RunTx != nil {
|
|
m.UseTx = true
|
|
m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error
|
|
m.UpFn = withoutContext(up.RunTx) // func(*sql.Tx) error
|
|
}
|
|
}
|
|
if down != nil {
|
|
m.goDown = updateMode(down)
|
|
|
|
if down.RunDB != nil {
|
|
m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error
|
|
m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error
|
|
}
|
|
if down.RunTx != nil {
|
|
m.UseTx = true
|
|
m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error
|
|
m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error
|
|
}
|
|
}
|
|
return m
|
|
}
|
|
|
|
// Migration struct represents either a SQL or Go migration.
|
|
//
|
|
// Avoid constructing migrations manually, use [NewGoMigration] function.
|
|
type Migration struct {
|
|
Type MigrationType
|
|
Version int64
|
|
// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
|
|
// have been registered globally and don't have a source file.
|
|
Source string
|
|
|
|
UpFnContext, DownFnContext GoMigrationContext
|
|
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
|
|
|
|
// These fields will be removed in a future major version. They are here for backwards
|
|
// compatibility and are an implementation detail.
|
|
Registered bool
|
|
UseTx bool
|
|
Next int64 // next version, or -1 if none
|
|
Previous int64 // previous version, -1 if none
|
|
|
|
// We still save the non-context versions in the struct in case someone is using them. Goose
|
|
// does not use these internally anymore in favor of the context-aware versions. These fields
|
|
// will be removed in a future major version.
|
|
|
|
UpFn GoMigration // Deprecated: use UpFnContext instead.
|
|
DownFn GoMigration // Deprecated: use DownFnContext instead.
|
|
UpFnNoTx GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
|
|
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
|
|
|
|
noVersioning bool
|
|
|
|
// These fields are used internally by goose and users are not expected to set them. Instead,
|
|
// use [NewGoMigration] to create a new go migration.
|
|
construct bool
|
|
goUp, goDown *GoFunc
|
|
|
|
sql sqlMigration
|
|
}
|
|
|
|
// sqlMigration holds the parsed statements of a .sql migration.
type sqlMigration struct {
	// Parsed reports whether the fields below have been populated. Parsing is deferred because
	// migrations are usually incremental: most runs only apply the newest files, so older ones
	// rarely need to be parsed at all.
	Parsed bool

	// Only read the following fields once Parsed is true.

	// UseTx records whether the parsed statements are meant to run inside a transaction.
	UseTx bool
	// Up and Down hold the statements for each direction.
	Up, Down []string
}
|
|
|
|
// GoFunc represents a Go migration function.
|
|
type GoFunc struct {
|
|
// Exactly one of these must be set, or both must be nil.
|
|
RunTx func(ctx context.Context, tx *sql.Tx) error
|
|
// -- OR --
|
|
RunDB func(ctx context.Context, db *sql.DB) error
|
|
|
|
// Mode is the transaction mode for the migration. When one of the run functions is set, the
|
|
// mode will be inferred from the function and the field is ignored. Users do not need to set
|
|
// this field when supplying a run function.
|
|
//
|
|
// If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil
|
|
// functions is to record a version in the version table without invoking a Go migration
|
|
// function.
|
|
//
|
|
// The only time this field is required is if BOTH run functions are nil AND you want to
|
|
// override the default transaction mode.
|
|
Mode TransactionMode
|
|
}
|
|
|
|
// TransactionMode describes whether a migration runs inside a database transaction.
//
// The zero value is deliberately not a valid mode; it means "unset", which lets
// NewGoMigration infer the mode from the supplied run function.
type TransactionMode int

const (
	// TransactionEnabled runs the migration inside a transaction.
	TransactionEnabled TransactionMode = 1
	// TransactionDisabled runs the migration without a transaction.
	TransactionDisabled TransactionMode = 2
)

// String returns the snake_case name of the mode. For values that are not one of the declared
// constants, it returns a placeholder that includes the numeric value.
func (m TransactionMode) String() string {
	if m == TransactionEnabled {
		return "transaction_enabled"
	}
	if m == TransactionDisabled {
		return "transaction_disabled"
	}
	return fmt.Sprintf("unknown transaction mode (%d)", m)
}
|
|
|
|
// MigrationRecord is a legacy record pairing a migration version with a timestamp and an
// applied flag.
//
// Deprecated: unused and will be removed in a future major version.
type MigrationRecord struct {
	VersionID int64     // migration version
	TStamp    time.Time // timestamp associated with the record
	IsApplied bool      // was this a result of up() or down()
}
|
|
|
|
func (m *Migration) String() string {
|
|
return fmt.Sprint(m.Source)
|
|
}
|
|
|
|
// Up runs an up migration.
|
|
func (m *Migration) Up(db *sql.DB) error {
|
|
ctx := context.Background()
|
|
return m.UpContext(ctx, db)
|
|
}
|
|
|
|
// UpContext runs an up migration.
|
|
func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
|
|
if err := m.run(ctx, db, true); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Down runs a down migration.
|
|
func (m *Migration) Down(db *sql.DB) error {
|
|
ctx := context.Background()
|
|
return m.DownContext(ctx, db)
|
|
}
|
|
|
|
// DownContext runs a down migration.
|
|
func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
|
|
if err := m.run(ctx, db, false); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
|
|
switch filepath.Ext(m.Source) {
|
|
case ".sql":
|
|
f, err := baseFS.Open(m.Source)
|
|
if err != nil {
|
|
return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", filepath.Base(m.Source), err)
|
|
}
|
|
defer f.Close()
|
|
|
|
statements, useTx, err := sqlparser.ParseSQLMigration(f, sqlparser.FromBool(direction), verbose)
|
|
if err != nil {
|
|
return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err)
|
|
}
|
|
|
|
start := time.Now()
|
|
if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
|
|
return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err)
|
|
}
|
|
finish := truncateDuration(time.Since(start))
|
|
|
|
if len(statements) > 0 {
|
|
log.Printf("OK %s (%s)", filepath.Base(m.Source), finish)
|
|
} else {
|
|
log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
|
|
}
|
|
|
|
case ".go":
|
|
if !m.Registered {
|
|
return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
|
|
}
|
|
start := time.Now()
|
|
var empty bool
|
|
if m.UseTx {
|
|
// Run go-based migration inside a tx.
|
|
fn := m.DownFnContext
|
|
if direction {
|
|
fn = m.UpFnContext
|
|
}
|
|
empty = (fn == nil)
|
|
if err := runGoMigration(
|
|
ctx,
|
|
db,
|
|
fn,
|
|
m.Version,
|
|
direction,
|
|
!m.noVersioning,
|
|
); err != nil {
|
|
return fmt.Errorf("ERROR go migration: %q: %w", filepath.Base(m.Source), err)
|
|
}
|
|
} else {
|
|
// Run go-based migration outside a tx.
|
|
fn := m.DownFnNoTxContext
|
|
if direction {
|
|
fn = m.UpFnNoTxContext
|
|
}
|
|
empty = (fn == nil)
|
|
if err := runGoMigrationNoTx(
|
|
ctx,
|
|
db,
|
|
fn,
|
|
m.Version,
|
|
direction,
|
|
!m.noVersioning,
|
|
); err != nil {
|
|
return fmt.Errorf("ERROR go migration no tx: %q: %w", filepath.Base(m.Source), err)
|
|
}
|
|
}
|
|
finish := truncateDuration(time.Since(start))
|
|
if !empty {
|
|
log.Printf("OK %s (%s)", filepath.Base(m.Source), finish)
|
|
} else {
|
|
log.Printf("EMPTY %s (%s)", filepath.Base(m.Source), finish)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func runGoMigrationNoTx(
|
|
ctx context.Context,
|
|
db *sql.DB,
|
|
fn GoMigrationNoTxContext,
|
|
version int64,
|
|
direction bool,
|
|
recordVersion bool,
|
|
) error {
|
|
if fn != nil {
|
|
// Run go migration function.
|
|
if err := fn(ctx, db); err != nil {
|
|
return fmt.Errorf("failed to run go migration: %w", err)
|
|
}
|
|
}
|
|
if recordVersion {
|
|
return insertOrDeleteVersionNoTx(ctx, db, version, direction)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func runGoMigration(
|
|
ctx context.Context,
|
|
db *sql.DB,
|
|
fn GoMigrationContext,
|
|
version int64,
|
|
direction bool,
|
|
recordVersion bool,
|
|
) error {
|
|
if fn == nil && !recordVersion {
|
|
return nil
|
|
}
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
if fn != nil {
|
|
// Run go migration function.
|
|
if err := fn(ctx, tx); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("failed to run go migration: %w", err)
|
|
}
|
|
}
|
|
if recordVersion {
|
|
if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("failed to update version: %w", err)
|
|
}
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
|
|
if direction {
|
|
return store.InsertVersion(ctx, tx, TableName(), version)
|
|
}
|
|
return store.DeleteVersion(ctx, tx, TableName(), version)
|
|
}
|
|
|
|
func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
|
|
if direction {
|
|
return store.InsertVersionNoTx(ctx, db, TableName(), version)
|
|
}
|
|
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
|
|
}
|
|
|
|
// NumericComponent extracts the version number from a migration file name of the form
// XXX_descriptivename.ext. Here XXX is the version and ext is either .sql or .go. Any
// directory components in filename are ignored.
//
// It returns an error if:
//   - the extension is not .sql or .go;
//   - the name has no '_' separator;
//   - the prefix is not a base-10 int64;
//   - the version is less than 1.
func NumericComponent(filename string) (int64, error) {
	base := filepath.Base(filename)

	switch filepath.Ext(base) {
	case ".go", ".sql":
		// supported migration type
	default:
		return 0, errors.New("migration file does not have .sql or .go file extension")
	}

	prefix, _, found := strings.Cut(base, "_")
	if !found {
		return 0, errors.New("no filename separator '_' found")
	}

	version, err := strconv.ParseInt(prefix, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", base, err)
	}
	if version < 1 {
		return 0, errors.New("migration version must be greater than zero")
	}
	return version, nil
}
|
|
|
|
// truncateDuration rounds d for display, keeping roughly two decimal places of its largest unit.
// Durations above a second are rounded to 10ms. Durations above a millisecond are rounded to
// 10µs, and durations above a microsecond to 10ns. Anything smaller is returned unchanged.
func truncateDuration(d time.Duration) time.Duration {
	switch {
	case d > time.Second:
		return d.Round(time.Second / 100)
	case d > time.Millisecond:
		return d.Round(time.Millisecond / 100)
	case d > time.Microsecond:
		return d.Round(time.Microsecond / 100)
	default:
		return d
	}
}
|
|
|
|
// ref returns a string that identifies the migration. This is used for logging and error messages.
|
|
func (m *Migration) ref() string {
|
|
return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
|
|
}
|