package goose import ( "context" "database/sql" "errors" "fmt" "io/fs" "math" "strconv" "strings" "sync" "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/controller" "github.com/pressly/goose/v3/internal/gooseutil" "github.com/pressly/goose/v3/internal/sqlparser" "go.uber.org/multierr" ) // Provider is a goose migration provider. type Provider struct { // mu protects all accesses to the provider and must be held when calling operations on the // database. mu sync.Mutex db *sql.DB store *controller.StoreController versionTableOnce sync.Once fsys fs.FS cfg config // migrations are ordered by version in ascending order. This list will never be empty and // contains all migrations known to the provider. migrations []*Migration } // NewProvider returns a new goose provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For // example, if the database dialect is "postgres", the database/sql driver could be // github.com/lib/pq or github.com/jackc/pgx. Each dialect has a corresponding [database.Dialect] // constant backed by a default [database.Store] implementation. For more advanced use cases, such // as using a custom table name or supplying a custom store implementation, see [WithStore]. // // fsys is the filesystem used to read migration files, but may be nil. Most users will want to use // [os.DirFS], os.DirFS("path/to/migrations"), to read migrations from the local filesystem. // However, it is possible to use a different "filesystem", such as [embed.FS] or filter out // migrations using [fs.Sub]. // // See [ProviderOption] for more information on configuring the provider. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } if fsys == nil { fsys = noopFS{} } cfg := config{ registered: make(map[int64]*Migration), excludePaths: make(map[string]bool), excludeVersions: make(map[int64]bool), logger: &stdLogger{}, } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { return nil, err } } // Allow users to specify a custom store implementation, but only if they don't specify a // dialect. If they specify a dialect, we'll use the default store implementation. if dialect == "" && cfg.store == nil { return nil, errors.New("dialect must not be empty") } if dialect != "" && cfg.store != nil { return nil, errors.New("dialect must be empty when using a custom store implementation") } var store database.Store if dialect != "" { var err error store, err = database.NewStore(dialect, DefaultTablename) if err != nil { return nil, err } } else { store = cfg.store } if store.Tablename() == "" { return nil, errors.New("invalid store implementation: table name must not be empty") } return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */) } func newProvider( db *sql.DB, store database.Store, fsys fs.FS, cfg config, global map[int64]*Migration, ) (*Provider, error) { // Collect migrations from the filesystem and merge with registered migrations. // // Note, we don't parse SQL migrations here. They are parsed lazily when required. // feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return // an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so // we should make it optional. filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions) if err != nil { return nil, err } versionToGoMigration := make(map[int64]*Migration) // Add user-registered Go migrations from the provider. for version, m := range cfg.registered { versionToGoMigration[version] = m } // Skip adding global Go migrations if explicitly disabled. if cfg.disableGlobalRegistry { // TODO(mf): let's add a warn-level log here to inform users if len(global) > 0. Would like // to add this once we're on go1.21 and leverage the new slog package. } else { for version, m := range global { if _, ok := versionToGoMigration[version]; ok { return nil, fmt.Errorf("global go migration conflicts with provider-registered go migration with version %d", version) } versionToGoMigration[version] = m } } // At this point we have all registered unique Go migrations (if any). We need to merge them // with SQL migrations from the filesystem. migrations, err := merge(filesystemSources, versionToGoMigration) if err != nil { return nil, err } if len(migrations) == 0 { return nil, ErrNoMigrations } return &Provider{ db: db, fsys: fsys, cfg: cfg, store: controller.NewStoreController(store), migrations: migrations, }, nil } // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { return p.status(ctx) } // HasPending returns true if there are pending migrations to apply, otherwise, it returns false. If // out-of-order migrations are disabled, yet some are detected, this method returns an error. // // Note, this method will not use a SessionLocker if one is configured. This allows callers to check // for pending migrations without blocking or being blocked by other operations. func (p *Provider) HasPending(ctx context.Context) (bool, error) { return p.hasPending(ctx) } // GetVersions returns the max database version and the target version to migrate to. // // Note, this method will not use a SessionLocker if one is configured. This allows callers to check // for versions without blocking or being blocked by other operations. func (p *Provider) GetVersions(ctx context.Context) (current, target int64, err error) { return p.getVersions(ctx) } // GetDBVersion returns the highest version recorded in the database, regardless of the order in // which migrations were applied. For example, if migrations were applied out of order (1,4,2,3), // this method returns 4. If no migrations have been applied, it returns 0. func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { if p.cfg.disableVersioning { return -1, errors.New("getting database version not supported when versioning is disabled") } return p.getDBMaxVersion(ctx, nil) } // ListSources returns a list of all migration sources known to the provider, sorted in ascending // order by version. The path field may be empty for manually registered migrations, such as Go // migrations registered using the [WithGoMigrations] option. func (p *Provider) ListSources() []*Source { sources := make([]*Source, 0, len(p.migrations)) for _, m := range p.migrations { sources = append(sources, &Source{ Type: m.Type, Path: m.Source, Version: m.Version, }) } return sources } // Ping attempts to ping the database to verify a connection is available. func (p *Provider) Ping(ctx context.Context) error { return p.db.PingContext(ctx) } // Close closes the database connection initially supplied to the provider. func (p *Provider) Close() error { return p.db.Close() } // ApplyVersion applies exactly one migration for the specified version. If there is no migration // available for the specified version, this method returns [ErrVersionNotFound]. If the migration // has already been applied, this method returns [ErrAlreadyApplied]. // // The direction parameter determines the migration direction: true for up migration and false for // down migration. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { res, err := p.apply(ctx, version, direction) if err != nil { return nil, err } // This should never happen, we must return exactly one result. if len(res) != 1 { versions := make([]string, 0, len(res)) for _, r := range res { versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) } return nil, fmt.Errorf( "unexpected number of migrations applied running apply, expecting exactly one result: %v", strings.Join(versions, ","), ) } return res[0], nil } // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { hasPending, err := p.HasPending(ctx) if err != nil { return nil, err } if !hasPending { return nil, nil } return p.up(ctx, false, math.MaxInt64) } // UpByOne applies the next pending migration. If there is no next migration to apply, this method // returns [ErrNoNextVersion]. func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { hasPending, err := p.HasPending(ctx) if err != nil { return nil, err } if !hasPending { return nil, ErrNoNextVersion } res, err := p.up(ctx, true, math.MaxInt64) if err != nil { return nil, err } if len(res) == 0 { return nil, ErrNoNextVersion } // This should never happen, we must return exactly one result. if len(res) != 1 { versions := make([]string, 0, len(res)) for _, r := range res { versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) } return nil, fmt.Errorf( "unexpected number of migrations applied running up-by-one, expecting exactly one result: %v", strings.Join(versions, ","), ) } return res[0], nil } // UpTo applies all pending migrations up to, and including, the specified version. If there are no // migrations to apply, this method returns empty list and nil error. // // For example, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9,10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { hasPending, err := p.HasPending(ctx) if err != nil { return nil, err } if !hasPending { return nil, nil } return p.up(ctx, false, version) } // Down rolls back the most recently applied migration. If there are no migrations to rollback, this // method returns [ErrNoNextVersion]. // // Note, migrations are rolled back in the order they were applied. And not in the reverse order of // the migration version. This only applies in scenarios where migrations are allowed to be applied // out of order. func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { res, err := p.down(ctx, true, 0) if err != nil { return nil, err } if len(res) == 0 { return nil, ErrNoNextVersion } // This should never happen, we must return exactly one result. if len(res) != 1 { versions := make([]string, 0, len(res)) for _, r := range res { versions = append(versions, strconv.FormatInt(r.Source.Version, 10)) } return nil, fmt.Errorf( "unexpected number of migrations applied running down, expecting exactly one result: %v", strings.Join(versions, ","), ) } return res[0], nil } // DownTo rolls back all migrations down to, but not including, the specified version. // // For example, if the current database version is 11,10,9... and the requested version is 9, only // migrations 11, 10 will be rolled back. // // Note, migrations are rolled back in the order they were applied. And not in the reverse order of // the migration version. This only applies in scenarios where migrations are allowed to be applied // out of order. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { if version < 0 { return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version) } return p.down(ctx, false, version) } // *** Internal methods *** func (p *Provider) up( ctx context.Context, byOne bool, version int64, ) (_ []*MigrationResult, retErr error) { if version < 1 { return nil, errInvalidVersion } conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() if len(p.migrations) == 0 { return nil, nil } var apply []*Migration if p.cfg.disableVersioning { if byOne { return nil, errors.New("up-by-one not supported when versioning is disabled") } apply = p.migrations } else { // optimize(mf): Listing all migrations from the database isn't great. // // The ideal implementation would be to query for the current max version and then apply // migrations greater than that version. However, a nice property of the current // implementation is that we can make stronger guarantees about unapplied migrations. // // In cases where users do not use out-of-order migrations, we want to surface an error if // there are older unapplied migrations. See https://github.com/pressly/goose/issues/761 for // more details. // // And in cases where users do use out-of-order migrations, we need to build a list of older // migrations that need to be applied, so we need to query for all migrations anyways. dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { return nil, err } if len(dbMigrations) == 0 { return nil, errMissingZeroVersion } versions, err := gooseutil.UpVersions( getVersionsFromMigrations(p.migrations), // fsys versions getVersionsFromListMigrations(dbMigrations), // db versions version, p.cfg.allowMissing, ) if err != nil { return nil, err } for _, v := range versions { m, err := p.getMigration(v) if err != nil { return nil, err } apply = append(apply, m) } } return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, byOne) } func (p *Provider) down( ctx context.Context, byOne bool, version int64, ) (_ []*MigrationResult, retErr error) { conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() if len(p.migrations) == 0 { return nil, nil } if p.cfg.disableVersioning { var downMigrations []*Migration if byOne { last := p.migrations[len(p.migrations)-1] downMigrations = []*Migration{last} } else { downMigrations = p.migrations } return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, byOne) } dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { return nil, err } if len(dbMigrations) == 0 { return nil, errMissingZeroVersion } // We never migrate the zero version down. if dbMigrations[0].Version == 0 { p.printf("no migrations to run, current version: 0") return nil, nil } var apply []*Migration for _, dbMigration := range dbMigrations { if dbMigration.Version <= version { break } m, err := p.getMigration(dbMigration.Version) if err != nil { return nil, err } apply = append(apply, m) } return p.runMigrations(ctx, conn, apply, sqlparser.DirectionDown, byOne) } func (p *Provider) apply( ctx context.Context, version int64, direction bool, ) (_ []*MigrationResult, retErr error) { if version < 1 { return nil, errInvalidVersion } m, err := p.getMigration(version) if err != nil { return nil, err } conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() result, err := p.store.GetMigration(ctx, conn, version) if err != nil && !errors.Is(err, database.ErrVersionNotFound) { return nil, err } // There are a few states here: // 1. direction is up // a. migration is applied, this is an error (ErrAlreadyApplied) // b. migration is not applied, apply it if direction && result != nil { return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) } // 2. direction is down // a. migration is applied, rollback // b. migration is not applied, this is an error (ErrNotApplied) if !direction && result == nil { return nil, fmt.Errorf("version %d: %w", version, ErrNotApplied) } d := sqlparser.DirectionDown if direction { d = sqlparser.DirectionUp } return p.runMigrations(ctx, conn, []*Migration{m}, d, true) } func (p *Provider) getVersions(ctx context.Context) (current, target int64, retErr error) { conn, cleanup, err := p.initialize(ctx, false) if err != nil { return -1, -1, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() target = p.migrations[len(p.migrations)-1].Version // If versioning is disabled, we always have pending migrations and the target version is the // last migration. if p.cfg.disableVersioning { return -1, target, nil } current, err = p.store.GetLatestVersion(ctx, conn) if err != nil { if errors.Is(err, database.ErrVersionNotFound) { return -1, target, errMissingZeroVersion } return -1, target, err } return current, target, nil } func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { conn, cleanup, err := p.initialize(ctx, false) if err != nil { return false, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() // If versioning is disabled, we always have pending migrations. if p.cfg.disableVersioning { return true, nil } // List all migrations from the database. Careful, optimizations here can lead to subtle bugs. // We have 2 important cases to consider: // // 1. Users have enabled out-of-order migrations, in which case we need to check if any // migrations are missing and report that there are pending migrations. Do not surface an // error because this is a valid state. // // 2. Users have disabled out-of-order migrations (default), in which case we need to check if all // migrations have been applied. We cannot check for the highest applied version because we lose the // ability to surface an error if an out-of-order migration was introduced. It would be silently // ignored and the user would not know that they have unapplied migrations. // // Maybe we could consider adding a flag to the provider such as IgnoreMissing, which would // allow silently ignoring missing migrations. This would be useful for users that have built // checks that prevent missing migrations from being introduced. dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { return false, err } apply, err := gooseutil.UpVersions( getVersionsFromMigrations(p.migrations), // fsys versions getVersionsFromListMigrations(dbMigrations), // db versions math.MaxInt64, p.cfg.allowMissing, ) if err != nil { return false, err } return len(apply) > 0, nil } func getVersionsFromMigrations(in []*Migration) []int64 { out := make([]int64, 0, len(in)) for _, m := range in { out = append(out, m.Version) } return out } func getVersionsFromListMigrations(in []*database.ListMigrationsResult) []int64 { out := make([]int64, 0, len(in)) for _, m := range in { out = append(out, m.Version) } return out } func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) }() status := make([]*MigrationStatus, 0, len(p.migrations)) for _, m := range p.migrations { migrationStatus := &MigrationStatus{ Source: &Source{ Type: m.Type, Path: m.Source, Version: m.Version, }, State: StatePending, } // If versioning is disabled, we can't check the database for applied migrations, so we // assume all migrations are pending. if !p.cfg.disableVersioning { dbResult, err := p.store.GetMigration(ctx, conn, m.Version) if err != nil && !errors.Is(err, database.ErrVersionNotFound) { return nil, err } if dbResult != nil { migrationStatus.State = StateApplied migrationStatus.AppliedAt = dbResult.Timestamp } } status = append(status, migrationStatus) } return status, nil } // getDBMaxVersion returns the highest version recorded in the database, regardless of the order in // which migrations were applied. conn may be nil, in which case a connection is initialized. func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) { if conn == nil { var cleanup func() error var err error conn, cleanup, err = p.initialize(ctx, true) if err != nil { return 0, err } defer func() { retErr = multierr.Append(retErr, cleanup()) }() } latest, err := p.store.GetLatestVersion(ctx, conn) if err != nil { if errors.Is(err, database.ErrVersionNotFound) { return 0, errMissingZeroVersion } return -1, err } return latest, nil }