diff --git a/Makefile b/Makefile index f3fc783..8b859f4 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,14 @@ test-packages: test-packages-short: go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) +coverage-short: + go test ./ -test.short $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out + go tool cover -html=coverage.out + +coverage: + go test ./ $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out + go tool cover -html=coverage.out + # # Integration-related targets # diff --git a/internal/testing/integration/postgres_locking_test.go b/internal/testing/integration/postgres_locking_test.go index be99837..bfced40 100644 --- a/internal/testing/integration/postgres_locking_test.go +++ b/internal/testing/integration/postgres_locking_test.go @@ -4,14 +4,18 @@ import ( "context" "database/sql" "errors" + "fmt" + "hash/crc64" "math/rand" "os" "sort" "sync" "testing" + "testing/fstest" "time" "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/testing/testdb" "github.com/pressly/goose/v3/lock" "github.com/stretchr/testify/require" @@ -406,6 +410,120 @@ func TestPostgresProviderLocking(t *testing.T) { }) } +func TestPostgresHasPending(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + db, cleanup, err := testdb.NewPostgres() + require.NoError(t, err) + t.Cleanup(cleanup) + + workers := 15 + + run := func(want bool) { + var g errgroup.Group + boolCh := make(chan bool, workers) + for i := 0; i < workers; i++ { + g.Go(func() error { + p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres")) + check.NoError(t, err) + hasPending, err := p.HasPending(context.Background()) + if err != nil { + return err + } + boolCh <- hasPending + return nil + + }) + } + check.NoError(t, g.Wait()) + close(boolCh) + // expect every value reported by the workers to equal want + for hasPending := range boolCh { 
check.Bool(t, hasPending, want) + } + } + t.Run("concurrent_has_pending", func(t *testing.T) { + run(true) + }) + + // apply all migrations + p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres")) + check.NoError(t, err) + _, err = p.Up(context.Background()) + check.NoError(t, err) + + t.Run("concurrent_no_pending", func(t *testing.T) { + run(false) + }) + + // Add a new migration file + last := p.ListSources()[len(p.ListSources())-1] + newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1) + fsys := fstest.MapFS{ + newVersion: &fstest.MapFile{Data: []byte(` +-- +goose Up +SELECT pg_sleep_for('4 seconds'); +`)}, + } + lockID := int64(crc64.Checksum([]byte(t.Name()), crc64.MakeTable(crc64.ECMA))) + // Create a new provider with the new migration file + sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 10s. Try every 1s up to 10 times. + require.NoError(t, err) + newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker)) + check.NoError(t, err) + check.Number(t, len(newProvider.ListSources()), 1) + oldProvider := p + check.Number(t, len(oldProvider.ListSources()), 6) + + var g errgroup.Group + g.Go(func() error { + hasPending, err := newProvider.HasPending(context.Background()) + if err != nil { + return err + } + check.Bool(t, hasPending, true) + return nil + }) + g.Go(func() error { + hasPending, err := oldProvider.HasPending(context.Background()) + if err != nil { + return err + } + check.Bool(t, hasPending, false) + return nil + }) + check.NoError(t, g.Wait()) + + // A new provider is running in the background with a session lock to simulate a long running + // migration. If older instances come up, they should not have any pending migrations and not be + // affected by the long running migration. 
Test the following scenario: + // https://github.com/pressly/goose/pull/507#discussion_r1266498077 + g.Go(func() error { + _, err := newProvider.Up(context.Background()) + return err + }) + time.Sleep(1 * time.Second) + isLocked, err := existsPgLock(context.Background(), db, lockID) + check.NoError(t, err) + check.Bool(t, isLocked, true) + hasPending, err := oldProvider.HasPending(context.Background()) + check.NoError(t, err) + check.Bool(t, hasPending, false) + // Wait for the long running migration to finish + check.NoError(t, g.Wait()) + // Check that the new migration was applied + hasPending, err = newProvider.HasPending(context.Background()) + check.NoError(t, err) + check.Bool(t, hasPending, false) + // The max version should be the new migration + currentVersion, err := newProvider.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, last.Version+1) +} + func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) { q := `SELECT EXISTS(SELECT 1 FROM pg_locks WHERE locktype='advisory' AND ((classid::bigint<<32)|objid::bigint)=$1)` row := db.QueryRowContext(ctx, q, lockID) diff --git a/provider.go b/provider.go index 24a9eb5..65674d5 100644 --- a/provider.go +++ b/provider.go @@ -23,13 +23,15 @@ type Provider struct { // database. mu sync.Mutex - db *sql.DB - store database.Store + db *sql.DB + store database.Store + versionTableOnce sync.Once fsys fs.FS cfg config - // migrations are ordered by version in ascending order. + // migrations are ordered by version in ascending order. This list will never be empty and + // contains all migrations known to the provider. migrations []*Migration } @@ -49,8 +51,6 @@ type Provider struct { // See [ProviderOption] for more information on configuring the provider. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. -// -// Experimental: This API is experimental and may change in the future. 
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") @@ -154,6 +154,14 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { return p.status(ctx) } +// HasPending returns true if there are pending migrations to apply, otherwise, it returns false. +// +// Note, this method will not use a SessionLocker if one is configured. This allows callers to check +// for pending migrations without blocking or being blocked by other operations. +func (p *Provider) HasPending(ctx context.Context) (bool, error) { + return p.hasPending(ctx) +} + // GetDBVersion returns the highest version recorded in the database, regardless of the order in // which migrations were applied. For example, if migrations were applied out of order (1,4,2,3), // this method returns 4. If no migrations have been applied, it returns 0. @@ -214,12 +222,26 @@ func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bo // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { + hasPending, err := p.HasPending(ctx) + if err != nil { + return nil, err + } + if !hasPending { + return nil, nil + } return p.up(ctx, false, math.MaxInt64) } // UpByOne applies the next pending migration. If there is no next migration to apply, this method -// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. +// returns [ErrNoNextVersion]. 
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { + hasPending, err := p.HasPending(ctx) + if err != nil { + return nil, err + } + if !hasPending { + return nil, ErrNoNextVersion + } res, err := p.up(ctx, true, math.MaxInt64) if err != nil { return nil, err @@ -247,6 +269,13 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { // For example, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9,10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + hasPending, err := p.HasPending(ctx) + if err != nil { + return nil, err + } + if !hasPending { + return nil, nil + } return p.up(ctx, false, version) } @@ -303,7 +332,7 @@ func (p *Provider) up( if version < 1 { return nil, errInvalidVersion } - conn, cleanup, err := p.initialize(ctx) + conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } @@ -345,7 +374,7 @@ func (p *Provider) down( byOne bool, version int64, ) (_ []*MigrationResult, retErr error) { - conn, cleanup, err := p.initialize(ctx) + conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } @@ -404,7 +433,7 @@ func (p *Provider) apply( if err != nil { return nil, err } - conn, cleanup, err := p.initialize(ctx) + conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } @@ -436,8 +465,55 @@ func (p *Provider) apply( return p.runMigrations(ctx, conn, []*Migration{m}, d, true) } +func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { + conn, cleanup, err := p.initialize(ctx, false) + if err != nil { + return false, fmt.Errorf("failed to initialize: %w", err) + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // If versioning is disabled, we 
always have pending migrations. + if p.cfg.disableVersioning { + return true, nil + } + if p.cfg.allowMissing { + // List all migrations from the database. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return false, err + } + // If there are no migrations in the database, we have pending migrations. + if len(dbMigrations) == 0 { + return true, nil + } + applied := make(map[int64]bool, len(dbMigrations)) + for _, m := range dbMigrations { + applied[m.Version] = true + } + // Iterate over all migrations and check if any are missing. + for _, m := range p.migrations { + if !applied[m.Version] { + return true, nil + } + } + return false, nil + } + // If out-of-order migrations are not allowed, we can optimize this by only checking whether the + // last migration the provider knows about is applied. + last := p.migrations[len(p.migrations)-1] + if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil { + if errors.Is(err, database.ErrVersionNotFound) { + return true, nil + } + return false, err + } + return false, nil +} + func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { - conn, cleanup, err := p.initialize(ctx) + conn, cleanup, err := p.initialize(ctx, true) if err != nil { return nil, fmt.Errorf("failed to initialize: %w", err) } @@ -478,7 +554,7 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64 if conn == nil { var cleanup func() error var err error - conn, cleanup, err = p.initialize(ctx) + conn, cleanup, err = p.initialize(ctx, true) if err != nil { return 0, err } diff --git a/provider_run.go b/provider_run.go index 4d07601..d6c4c9f 100644 --- a/provider_run.go +++ b/provider_run.go @@ -14,6 +14,7 @@ import ( "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/sqlparser" + "github.com/sethvargo/go-retry" "go.uber.org/multierr" ) @@ -51,8 +52,14 @@ func (p *Provider) resolveUpMigrations( if len(collected) > 1 { msg += "s" } - 
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]", - len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","), + var versionsMsg string + if len(collected) > 1 { + versionsMsg = "versions " + strings.Join(collected, ",") + } else { + versionsMsg = "version " + collected[0] + } + return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): %s", + len(missingMigrations), msg, dbMaxVersion, versionsMsg, ) } for _, missingVersion := range missingMigrations { @@ -291,7 +298,7 @@ func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (re return tx.Commit() } -func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { +func (p *Provider) initialize(ctx context.Context, useSessionLocker bool) (*sql.Conn, func() error, error) { p.mu.Lock() conn, err := p.db.Conn(ctx) if err != nil { @@ -303,7 +310,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err p.mu.Unlock() return conn.Close() } - if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { + if useSessionLocker && p.cfg.sessionLocker != nil && p.cfg.lockEnabled { + l := p.cfg.sessionLocker if err := l.SessionLock(ctx, conn); err != nil { return nil, nil, multierr.Append(err, cleanup()) } @@ -320,7 +328,7 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err } } // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't - // need the version table because no versions are being recorded. + // need the version table because no versions are being tracked. 
if !p.cfg.disableVersioning { if err := p.ensureVersionTable(ctx, conn); err != nil { return nil, nil, multierr.Append(err, cleanup()) @@ -329,36 +337,61 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err return conn, cleanup, nil } -func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { - // existor is an interface that extends the Store interface with a method to check if the - // version table exists. This API is not stable and may change in the future. - type existor interface { - TableExists(context.Context, database.DBTxConn, string) (bool, error) - } - if e, ok := p.store.(existor); ok { - exists, err := e.TableExists(ctx, conn, p.store.Tablename()) - if err != nil { - return fmt.Errorf("failed to check if version table exists: %w", err) +func (p *Provider) ensureVersionTable( + ctx context.Context, + conn *sql.Conn, +) (retErr error) { + // There are 2 optimizations here: + // - 1. We attempt to create the version table at most once per Provider instance (sync.Once). NOTE(review): if that single attempt fails, later calls skip it and return nil instead of retrying or repeating the error. + // - 2. We retry the operation a few times in case the table is being created concurrently. + // + // Regarding item 2, certain goose operations, like HasPending, don't respect a SessionLocker. + // So, when goose is run for the first time in a multi-instance environment, it's possible that + // multiple instances will try to create the version table at the same time. This is why we + // retry this operation a few times. Best case, the table is created by one instance and all the + // other instances see that change immediately. Worst case, all instances try to create the + // table at the same time, but only one will succeed and the others will retry. 
+ p.versionTableOnce.Do(func() { + retErr = p.tryEnsureVersionTable(ctx, conn) + }) + return retErr +} + +func (p *Provider) tryEnsureVersionTable(ctx context.Context, conn *sql.Conn) error { + b := retry.NewConstant(1 * time.Second) + b = retry.WithMaxRetries(3, b) + return retry.Do(ctx, b, func(ctx context.Context) error { + if e, ok := p.store.(interface { + TableExists(context.Context, database.DBTxConn, string) (bool, error) + }); ok { + exists, err := e.TableExists(ctx, conn, p.store.Tablename()) + if err != nil { + return fmt.Errorf("failed to check if version table exists: %w", err) + } + if exists { + return nil + } + } else { + // This chicken-and-egg behavior is the fallback for all existing implementations of the + // Store interface. We check if the version table exists by querying for the initial + // version, but the table may not exist yet. It's important this runs outside of a + // transaction to avoid failing the transaction. + if res, err := p.store.GetMigration(ctx, conn, 0); err == nil && res != nil { + return nil + } } - if exists { - return nil + if err := beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := p.store.CreateVersionTable(ctx, tx); err != nil { + return err + } + return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) + }); err != nil { + // Mark the error as retryable so we can try again. It's possible that another instance + // is creating the table at the same time and the checks above will succeed on the next + // iteration. + return retry.RetryableError(fmt.Errorf("failed to create version table: %w", err)) } - } else { - // feat(mf): this is where we can check if the version table exists instead of trying to fetch - // from a table that may not exist. 
https://github.com/pressly/goose/issues/461 - res, err := p.store.GetMigration(ctx, conn, 0) - if err == nil && res != nil { - return nil - } - } - return beginTx(ctx, conn, func(tx *sql.Tx) error { - if err := p.store.CreateVersionTable(ctx, tx); err != nil { - return err - } - if p.cfg.disableVersioning { - return nil - } - return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) + return nil }) } diff --git a/provider_run_test.go b/provider_run_test.go index 914597f..59ec623 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -775,6 +775,53 @@ func TestProviderApply(t *testing.T) { check.Bool(t, errors.Is(err, goose.ErrNotApplied), true) } +func TestHasPending(t *testing.T) { + t.Parallel() + t.Run("allow_out_of_order", func(t *testing.T) { + ctx := context.Background() + p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(), + goose.WithAllowOutofOrder(true), + ) + check.NoError(t, err) + // Some migrations have been applied out of order. + _, err = p.ApplyVersion(ctx, 1, true) + check.NoError(t, err) + _, err = p.ApplyVersion(ctx, 3, true) + check.NoError(t, err) + hasPending, err := p.HasPending(ctx) + check.NoError(t, err) + check.Bool(t, hasPending, true) + // Apply the missing migrations. + _, err = p.Up(ctx) + check.NoError(t, err) + // All migrations have been applied. + hasPending, err = p.HasPending(ctx) + check.NoError(t, err) + check.Bool(t, hasPending, false) + }) + t.Run("disallow_out_of_order", func(t *testing.T) { + ctx := context.Background() + p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(), + goose.WithAllowOutofOrder(false), + ) + check.NoError(t, err) + // Some migrations have been applied. 
+ _, err = p.ApplyVersion(ctx, 1, true) + check.NoError(t, err) + _, err = p.ApplyVersion(ctx, 2, true) + check.NoError(t, err) + hasPending, err := p.HasPending(ctx) + check.NoError(t, err) + check.Bool(t, hasPending, true) + _, err = p.Up(ctx) + check.NoError(t, err) + // All migrations have been applied. + hasPending, err = p.HasPending(ctx) + check.NoError(t, err) + check.Bool(t, hasPending, false) + }) +} + type customStoreSQLite3 struct { database.Store }