goose/internal/provider/migration.go

187 lines
4.4 KiB
Go

package provider
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"github.com/pressly/goose/v3/database"
)
// migration is a single migration loaded from a Source. It is either a Go migration or a SQL
// migration, but never both; Source.Type selects which of the two is relevant.
type migration struct {
	// Source describes where the migration came from. Its Type decides whether the Go or the
	// SQL field below is consulted.
	Source Source
	// A migration is either a Go migration or a SQL migration, but never both.
	//
	// SQL migrations are parsed lazily: the SQL field stays nil until the migration has been
	// parsed, and the run methods report "sql migration has not been parsed" while it is nil.
	// This is an optimization to avoid parsing a SQL migration if it is never required. Also,
	// the majority of the time migrations are incremental, so it is likely that the user will
	// only want to run the last few migrations, and there is no need to parse ALL prior
	// migrations.
	//
	// At most one of these fields is set, matching Source.Type (the SQL field may still be nil
	// for a SQL migration that has not been parsed yet):
	Go *goMigration
	// -- OR --
	SQL *sqlMigration
}
// useTx reports whether the migration should run inside a transaction for the given direction
// (true = up, false = down).
//
// A SQL migration that has not been parsed yet (SQL is nil) reports false instead of panicking
// on a nil pointer dereference. A caller that then runs it via runNoTx/runConn gets the
// explicit "sql migration has not been parsed" error from those methods.
func (m *migration) useTx(direction bool) bool {
	switch m.Source.Type {
	case TypeSQL:
		if m.SQL == nil {
			return false
		}
		return m.SQL.UseTx
	case TypeGo:
		if m.Go == nil || m.Go.isEmpty(direction) {
			return false
		}
		// A Go migration runs in a transaction only when its function for this direction has a
		// transactional Run set; isEmpty(direction) == false guarantees up/down is non-nil here.
		if direction {
			return m.Go.up.Run != nil
		}
		return m.Go.down.Run != nil
	}
	// This should never happen.
	return false
}
// isEmpty reports whether the migration has nothing to run in the given direction
// (true = up, false = down).
//
// An unparsed SQL migration (nil SQL), a nil Go migration, and an unrecognized source type are
// all treated as empty. For a non-nil Go migration this delegates to goMigration.isEmpty, which
// panics when neither an up nor a down function is set.
func (m *migration) isEmpty(direction bool) bool {
	kind := m.Source.Type
	if kind == TypeSQL {
		if m.SQL == nil {
			return true
		}
		return m.SQL.isEmpty(direction)
	}
	if kind == TypeGo {
		if m.Go == nil {
			return true
		}
		return m.Go.isEmpty(direction)
	}
	return true
}
// filename returns the final element of the migration's source path, for example
// "00001_init.sql" for "migrations/00001_init.sql".
func (m *migration) filename() string {
	path := m.Source.Path
	return filepath.Base(path)
}
// run applies the migration in the given direction (true = up, false = down) using the
// supplied transaction.
//
// SQL migrations must already be parsed. Go migrations are delegated to goMigration.run, which
// returns nil for a nil receiver.
func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
	if m.Source.Type == TypeSQL {
		if m.SQL != nil {
			return m.SQL.run(ctx, tx, direction)
		}
		return fmt.Errorf("tx: sql migration has not been parsed")
	}
	if m.Source.Type == TypeGo {
		return m.Go.run(ctx, tx, direction)
	}
	// Unreachable for well-formed migrations, which are always either SQL or Go.
	name := filepath.Base(m.Source.Path)
	return fmt.Errorf("tx: failed to run migration %s: neither sql or go", name)
}
// runNoTx applies the migration in the given direction (true = up, false = down) directly on
// db, without opening a transaction.
//
// SQL migrations must already be parsed. Go migrations are delegated to goMigration.runNoTx,
// which returns nil for a nil receiver.
func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
	if m.Source.Type == TypeSQL {
		if m.SQL != nil {
			return m.SQL.run(ctx, db, direction)
		}
		return fmt.Errorf("db: sql migration has not been parsed")
	}
	if m.Source.Type == TypeGo {
		return m.Go.runNoTx(ctx, db, direction)
	}
	// Unreachable for well-formed migrations, which are always either SQL or Go.
	name := filepath.Base(m.Source.Path)
	return fmt.Errorf("db: failed to run migration %s: neither sql or go", name)
}
// runConn applies the migration in the given direction (true = up, false = down) on a single
// dedicated connection, without opening a transaction.
//
// Only parsed SQL migrations are supported. Go migrations always return an error because their
// functions accept *sql.Tx or *sql.DB, not *sql.Conn.
func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error {
	if m.Source.Type == TypeSQL {
		if m.SQL != nil {
			return m.SQL.run(ctx, conn, direction)
		}
		return fmt.Errorf("conn: sql migration has not been parsed")
	}
	if m.Source.Type == TypeGo {
		return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
	}
	// Unreachable for well-formed migrations, which are always either SQL or Go.
	name := filepath.Base(m.Source.Path)
	return fmt.Errorf("conn: failed to run migration %s: neither sql or go", name)
}
// goMigration holds the Go functions for one migration. Either direction may be nil, but
// goMigration.isEmpty panics if both are nil.
type goMigration struct {
	// fullpath is the migration's file path, as passed to newGoMigration.
	fullpath string
	// up is the function for the up direction; nil means there is nothing to run up.
	up *GoMigrationFunc
	// down is the function for the down direction; nil means there is nothing to run down.
	down *GoMigrationFunc
}
// isEmpty reports whether g has no function for the given direction (true = up, false = down).
// It panics if g has neither an up nor a down function.
func (g *goMigration) isEmpty(direction bool) bool {
	noUp, noDown := g.up == nil, g.down == nil
	if noUp && noDown {
		panic("go migration has no up or down")
	}
	if !direction {
		return noDown
	}
	return noUp
}
// newGoMigration builds a goMigration for the file at fullpath with the given up and down
// functions. No validation is done here; note that goMigration.isEmpty panics if both up and
// down are nil.
func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration {
	g := new(goMigration)
	g.fullpath = fullpath
	g.up = up
	g.down = down
	return g
}
func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.Tx) error
if direction && g.up != nil {
fn = g.up.Run
}
if !direction && g.down != nil {
fn = g.down.Run
}
if fn != nil {
return fn(ctx, tx)
}
return nil
}
func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
if g == nil {
return nil
}
var fn func(context.Context, *sql.DB) error
if direction && g.up != nil {
fn = g.up.RunNoTx
}
if !direction && g.down != nil {
fn = g.down.RunNoTx
}
if fn != nil {
return fn(ctx, db)
}
return nil
}
// sqlMigration is the parsed form of a SQL migration file.
type sqlMigration struct {
	// UseTx reports whether the statements should be executed inside a transaction.
	UseTx bool
	// UpStatements and DownStatements are executed in order for the up and down
	// directions respectively.
	UpStatements, DownStatements []string
}

// isEmpty reports whether s has no statements for the given direction
// (true = up, false = down).
func (s *sqlMigration) isEmpty(direction bool) bool {
	stmts := s.DownStatements
	if direction {
		stmts = s.UpStatements
	}
	return len(stmts) == 0
}
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
} else {
statements = s.DownStatements
}
for _, stmt := range statements {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}
}
return nil
}