goose/migration.go

238 lines
5.8 KiB
Go

package goose
import (
"context"
"database/sql"
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/pressly/goose/v3/internal/sqlparser"
)
// MigrationRecord holds the state recorded for one migration version.
//
// NOTE(review): this type is not referenced within this file; it appears to
// model an entry of the version table — confirm against its callers.
type MigrationRecord struct {
	VersionID int64     // migration version this record refers to (presumably matches Migration.Version)
	TStamp    time.Time // timestamp of the record; exact meaning not shown here — confirm
	IsApplied bool      // was this a result of up() or down()
}
// Migration describes a single versioned migration, backed either by a .sql
// file (read from baseFS) or by Go functions registered for a .go source file.
type Migration struct {
	Version  int64
	Next     int64  // next version, or -1 if none
	Previous int64  // previous version, -1 if none
	Source   string // path to .sql script or go file
	// Registered must be true for ".go" migrations; run rejects unregistered
	// Go migrations with an error instead of executing them.
	Registered bool
	// UseTx selects which pair of Go functions run uses: UpFn/DownFn (executed
	// inside a transaction) when true, UpFnNoTx/DownFnNoTx (executed directly
	// on the *sql.DB) when false. SQL migrations do not consult this field;
	// their transaction mode comes from the parsed SQL file.
	UseTx bool
	// UpFn and DownFn are the transactional Go migration functions for the up
	// and down directions; a nil function is reported as "EMPTY".
	UpFn, DownFn GoMigration
	// UpFnNoTx and DownFnNoTx are the non-transactional Go migration functions
	// for the up and down directions; a nil function is reported as "EMPTY".
	UpFnNoTx, DownFnNoTx GoMigrationNoTx
	// noVersioning, when true, makes run skip inserting/deleting Version in the
	// version table for Go migrations. For SQL migrations it is forwarded to
	// runSQLMigration (presumably with the same meaning — confirm there).
	noVersioning bool
}
// String implements fmt.Stringer by returning the migration's source path
// unchanged.
func (m *Migration) String() string {
	return m.Source
}
// Up runs an up migration against db using a background context.
// Any error from the underlying run is returned unchanged.
func (m *Migration) Up(db *sql.DB) error {
	return m.run(context.Background(), db, true)
}
// Down runs a down migration against db using a background context.
// Any error from the underlying run is returned unchanged.
func (m *Migration) Down(db *sql.DB) error {
	return m.run(context.Background(), db, false)
}
// run executes the migration in the given direction (true = up, false = down).
//
// Dispatch is by the extension of m.Source:
//   - ".sql": the file is opened from baseFS, parsed for the requested
//     direction, and executed via runSQLMigration.
//   - ".go": the registered Go function for the direction is executed, inside
//     a transaction when m.UseTx is set and directly on db otherwise.
//     Unregistered Go migrations are rejected with an error.
//   - any other extension: nothing is executed and nil is returned.
//
// On success an "OK" line is logged, or "EMPTY" when there was nothing to run
// (no parsed SQL statements, or no Go function for this direction).
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
	name := filepath.Base(m.Source)

	switch filepath.Ext(m.Source) {
	case ".sql":
		r, err := baseFS.Open(m.Source)
		if err != nil {
			return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", name, err)
		}
		defer r.Close()

		stmts, useTx, err := sqlparser.ParseSQLMigration(r, sqlparser.FromBool(direction), verbose)
		if err != nil {
			return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", name, err)
		}

		t0 := time.Now()
		if err := runSQLMigration(ctx, db, stmts, useTx, m.Version, direction, m.noVersioning); err != nil {
			return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", name, err)
		}
		elapsed := truncateDuration(time.Since(t0))

		if len(stmts) == 0 {
			log.Printf("EMPTY %s (%s)\n", name, elapsed)
		} else {
			log.Printf("OK %s (%s)\n", name, elapsed)
		}

	case ".go":
		if !m.Registered {
			return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
		}

		// The version bookkeeping is skipped only for unversioned migrations.
		record := !m.noVersioning
		t0 := time.Now()
		var isEmpty bool

		if !m.UseTx {
			// Execute directly on db, outside any transaction.
			fnNoTx := m.UpFnNoTx
			if !direction {
				fnNoTx = m.DownFnNoTx
			}
			isEmpty = fnNoTx == nil
			if err := runGoMigrationNoTx(ctx, db, fnNoTx, m.Version, direction, record); err != nil {
				return fmt.Errorf("ERROR go migration no tx: %q: %w", name, err)
			}
		} else {
			// Execute inside a single transaction.
			fnTx := m.UpFn
			if !direction {
				fnTx = m.DownFn
			}
			isEmpty = fnTx == nil
			if err := runGoMigration(ctx, db, fnTx, m.Version, direction, record); err != nil {
				return fmt.Errorf("ERROR go migration: %q: %w", name, err)
			}
		}

		elapsed := truncateDuration(time.Since(t0))
		if isEmpty {
			log.Printf("EMPTY %s (%s)\n", name, elapsed)
		} else {
			log.Printf("OK %s (%s)\n", name, elapsed)
		}
	}
	return nil
}
// runGoMigrationNoTx runs fn directly against db (no transaction) and then,
// when recordVersion is true, inserts (direction == true) or deletes
// (direction == false) version in the table named by TableName().
//
// A nil fn is skipped, so only the version bookkeeping is performed. Because
// the two steps are not transactional, a failure while recording the version
// leaves whatever fn already did in place.
func runGoMigrationNoTx(
	ctx context.Context,
	db *sql.DB,
	fn GoMigrationNoTx,
	version int64,
	direction bool,
	recordVersion bool,
) error {
	if fn != nil {
		if err := fn(db); err != nil {
			return fmt.Errorf("failed to run go migration: %w", err)
		}
	}
	if !recordVersion {
		return nil
	}
	return insertOrDeleteVersionNoTx(ctx, db, version, direction)
}
// runGoMigration runs fn inside a single transaction and, when recordVersion
// is true, inserts (direction == true) or deletes (direction == false) version
// in the table named by TableName() within that same transaction, so the
// migration and its bookkeeping commit or roll back together.
//
// If fn is nil and no version needs recording, it returns nil without opening
// a transaction. Errors are wrapped with the step that failed.
func runGoMigration(
	ctx context.Context,
	db *sql.DB,
	fn GoMigration,
	version int64,
	direction bool,
	recordVersion bool,
) error {
	if fn == nil && !recordVersion {
		return nil
	}
	// BeginTx ties the transaction to ctx: if ctx is cancelled before Commit,
	// database/sql rolls the transaction back.
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Roll back on every path that does not reach a successful Commit,
	// including a panic inside the user-supplied fn (which would otherwise
	// leak the transaction and its connection). After Commit, Rollback just
	// returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if fn != nil {
		if err := fn(tx); err != nil {
			return fmt.Errorf("failed to run go migration: %w", err)
		}
	}
	if recordVersion {
		if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil {
			return fmt.Errorf("failed to update version: %w", err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
if direction {
return store.InsertVersion(ctx, tx, TableName(), version)
}
return store.DeleteVersion(ctx, tx, TableName(), version)
}
func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
if direction {
return store.InsertVersionNoTx(ctx, db, TableName(), version)
}
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
}
// NumericComponent extracts the version number from a migration file name of
// the form XXX_descriptivename.ext, where XXX is the version number and ext is
// either ".sql" or ".go". Any leading directories in name are ignored.
//
// It returns an error when the extension is not recognized, when the base name
// contains no '_' separator, when XXX is not a valid base-10 int64, or when
// the parsed version is not greater than zero. On error the returned version
// is always 0.
func NumericComponent(name string) (int64, error) {
	base := filepath.Base(name)
	if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
		return 0, errors.New("not a recognized migration file type")
	}
	idx := strings.Index(base, "_")
	if idx < 0 {
		return 0, errors.New("no filename separator '_' found")
	}
	n, err := strconv.ParseInt(base[:idx], 10, 64)
	if err != nil {
		// ParseInt returns a clamped value (e.g. MaxInt64) on overflow; never
		// leak it alongside an error. Wrapping with %w keeps the underlying
		// *strconv.NumError reachable via errors.As / errors.Is.
		return 0, fmt.Errorf("failed to parse version from migration file %q: %w", base, err)
	}
	if n <= 0 {
		return 0, errors.New("migration IDs must be greater than zero")
	}
	return n, nil
}
func truncateDuration(d time.Duration) time.Duration {
for _, v := range []time.Duration{
time.Second,
time.Millisecond,
time.Microsecond,
} {
if d > v {
return d.Round(v / time.Duration(100))
}
}
return d
}