feat: Add provider HasPending method (#751)

pull/753/head
Michael Fridman 2024-04-21 13:05:54 -04:00 committed by GitHub
parent 7e96a2281a
commit 1ad801c2f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 326 additions and 44 deletions

View File

@ -42,6 +42,14 @@ test-packages:
# Run tests in short mode (-test.short) for every package listed by `go list ./...`, excluding
# those whose path matches /tests, /bin, /cmd or /examples.
test-packages-short:
go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
# Run the root package's tests in short mode with coverage enabled, write the profile to
# coverage.out, then render that profile as HTML.
coverage-short:
go test ./ -test.short $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out
# Same as coverage-short but without -test.short: run the root package's tests with coverage
# enabled, write the profile to coverage.out, then render that profile as HTML.
coverage:
go test ./ $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out
#
# Integration-related targets
#

View File

@ -4,14 +4,18 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"hash/crc64"
"math/rand"
"os"
"sort"
"sync"
"testing"
"testing/fstest"
"time"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testing/testdb"
"github.com/pressly/goose/v3/lock"
"github.com/stretchr/testify/require"
@ -406,6 +410,120 @@ func TestPostgresProviderLocking(t *testing.T) {
})
}
// TestPostgresHasPending exercises Provider.HasPending against a Postgres database obtained from
// testdb.NewPostgres. It covers three situations:
//   - many providers calling HasPending concurrently before any migration is applied,
//   - the same after all migrations have been applied, and
//   - an older provider calling HasPending while a newer provider holds the session lock and runs
//     a slow migration.
//
// Helpers that may call t.FailNow (check.*, require.*) are only invoked on the goroutine running
// the test or subtest. Worker goroutines report problems by returning an error through the
// errgroup, because FailNow must not be called from goroutines other than the one running the
// test.
//
// The test is skipped in short mode.
func TestPostgresHasPending(t *testing.T) {
	t.Parallel()
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	db, cleanup, err := testdb.NewPostgres()
	require.NoError(t, err)
	t.Cleanup(cleanup)

	const workers = 15

	// run starts `workers` goroutines, each with its own provider, and asserts that every one of
	// them observes HasPending == want. It takes the *testing.T of the calling (sub)test so that a
	// failure is attributed to, and aborts, the right test.
	run := func(t *testing.T, want bool) {
		t.Helper()
		var g errgroup.Group
		boolCh := make(chan bool, workers)
		for i := 0; i < workers; i++ {
			g.Go(func() error {
				p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
				if err != nil {
					return err
				}
				hasPending, err := p.HasPending(context.Background())
				if err != nil {
					return err
				}
				// The channel is buffered to `workers`, so this send never blocks.
				boolCh <- hasPending
				return nil
			})
		}
		check.NoError(t, g.Wait())
		close(boolCh)
		// Expect every worker to have observed the wanted value.
		for hasPending := range boolCh {
			check.Bool(t, hasPending, want)
		}
	}

	t.Run("concurrent_has_pending", func(t *testing.T) {
		run(t, true)
	})

	// Apply all migrations.
	p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
	check.NoError(t, err)
	_, err = p.Up(context.Background())
	check.NoError(t, err)

	t.Run("concurrent_no_pending", func(t *testing.T) {
		run(t, false)
	})

	// Add a new migration file, one version above the last known source.
	last := p.ListSources()[len(p.ListSources())-1]
	newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1)
	fsys := fstest.MapFS{
		newVersion: &fstest.MapFile{Data: []byte(`
-- +goose Up
SELECT pg_sleep_for('4 seconds');
`)},
	}
	lockID := int64(crc64.Checksum([]byte(t.Name()), crc64.MakeTable(crc64.ECMA)))

	// Create a new provider with the new migration file. The session locker retries acquiring
	// the lock every 1s, up to 10 times.
	sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID))
	require.NoError(t, err)
	newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker))
	check.NoError(t, err)
	check.Number(t, len(newProvider.ListSources()), 1)
	oldProvider := p
	check.Number(t, len(oldProvider.ListSources()), 6)

	// The new provider sees its extra migration as pending, the old provider does not. Mismatches
	// are returned as errors and asserted on the test goroutine via g.Wait below.
	var g errgroup.Group
	g.Go(func() error {
		hasPending, err := newProvider.HasPending(context.Background())
		if err != nil {
			return err
		}
		if !hasPending {
			return fmt.Errorf("new provider: HasPending = %t, want %t", hasPending, true)
		}
		return nil
	})
	g.Go(func() error {
		hasPending, err := oldProvider.HasPending(context.Background())
		if err != nil {
			return err
		}
		if hasPending {
			return fmt.Errorf("old provider: HasPending = %t, want %t", hasPending, false)
		}
		return nil
	})
	check.NoError(t, g.Wait())

	// A new provider is running in the background with a session lock to simulate a long running
	// migration. If older instances come up, they should not have any pending migrations and not be
	// affected by the long running migration. Test the following scenario:
	// https://github.com/pressly/goose/pull/507#discussion_r1266498077
	g.Go(func() error {
		_, err := newProvider.Up(context.Background())
		return err
	})
	// Give the background Up a moment to acquire the session lock before inspecting pg_locks.
	time.Sleep(1 * time.Second)
	isLocked, err := existsPgLock(context.Background(), db, lockID)
	check.NoError(t, err)
	check.Bool(t, isLocked, true)
	hasPending, err := oldProvider.HasPending(context.Background())
	check.NoError(t, err)
	check.Bool(t, hasPending, false)

	// Wait for the long running migration to finish.
	check.NoError(t, g.Wait())

	// Check that the new migration was applied.
	hasPending, err = newProvider.HasPending(context.Background())
	check.NoError(t, err)
	check.Bool(t, hasPending, false)

	// The max version should be the new migration.
	currentVersion, err := newProvider.GetDBVersion(context.Background())
	check.NoError(t, err)
	check.Number(t, currentVersion, last.Version+1)
}
func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
q := `SELECT EXISTS(SELECT 1 FROM pg_locks WHERE locktype='advisory' AND ((classid::bigint<<32)|objid::bigint)=$1)`
row := db.QueryRowContext(ctx, q, lockID)

View File

@ -23,13 +23,15 @@ type Provider struct {
// database.
mu sync.Mutex
db *sql.DB
store database.Store
db *sql.DB
store database.Store
versionTableOnce sync.Once
fsys fs.FS
cfg config
// migrations are ordered by version in ascending order.
// migrations are ordered by version in ascending order. This list will never be empty and
// contains all migrations known to the provider.
migrations []*Migration
}
@ -49,8 +51,6 @@ type Provider struct {
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
@ -154,6 +154,14 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return p.status(ctx)
}
// HasPending reports whether at least one migration still needs to be applied.
//
// A configured SessionLocker is deliberately not used by this method, so callers can check for
// pending migrations without blocking on, or being blocked by, other operations.
func (p *Provider) HasPending(ctx context.Context) (bool, error) {
	pending, err := p.hasPending(ctx)
	return pending, err
}
// GetDBVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
// this method returns 4. If no migrations have been applied, it returns 0.
@ -214,12 +222,26 @@ func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bo
// Up applies every pending migration. When HasPending reports that nothing is pending, Up returns
// a nil slice and a nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
	switch pending, err := p.HasPending(ctx); {
	case err != nil:
		return nil, err
	case !pending:
		// Nothing to apply, so p.up is never invoked.
		return nil, nil
	}
	return p.up(ctx, false, math.MaxInt64)
}
// UpByOne applies the next pending migration. If there is no next migration to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
// returns [ErrNoNextVersion].
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, ErrNoNextVersion
}
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
@ -247,6 +269,13 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
// For example, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
//
// If HasPending reports that nothing is pending, UpTo returns nil, nil. That check happens before
// p.up is called, so in that case the version argument is not inspected.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
	pending, err := p.HasPending(ctx)
	if err != nil || !pending {
		// Either the check failed (err is returned as-is) or nothing is pending (err is nil).
		return nil, err
	}
	return p.up(ctx, false, version)
}
@ -303,7 +332,7 @@ func (p *Provider) up(
if version < 1 {
return nil, errInvalidVersion
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
@ -345,7 +374,7 @@ func (p *Provider) down(
byOne bool,
version int64,
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
@ -404,7 +433,7 @@ func (p *Provider) apply(
if err != nil {
return nil, err
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
@ -436,8 +465,55 @@ func (p *Provider) apply(
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
}
// hasPending is the implementation behind HasPending. It initializes a connection without using
// the session locker and reports whether any migration known to the provider is unapplied.
//
// Any error returned by the connection cleanup is appended to the returned error.
func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
	conn, cleanup, err := p.initialize(ctx, false)
	if err != nil {
		return false, fmt.Errorf("failed to initialize: %w", err)
	}
	defer func() {
		retErr = multierr.Append(retErr, cleanup())
	}()

	switch {
	case p.cfg.disableVersioning:
		// With versioning disabled, migrations are always considered pending.
		return true, nil

	case p.cfg.allowMissing:
		// Out-of-order migrations are allowed, so compare the set of versions recorded in the
		// database against every migration the provider knows about.
		rows, err := p.store.ListMigrations(ctx, conn)
		if err != nil {
			return false, err
		}
		// Nothing recorded in the database means there are pending migrations.
		if len(rows) == 0 {
			return true, nil
		}
		seen := make(map[int64]bool, len(rows))
		for _, row := range rows {
			seen[row.Version] = true
		}
		for _, known := range p.migrations {
			if !seen[known.Version] {
				return true, nil
			}
		}
		return false, nil
	}

	// Out-of-order migrations are not allowed, so it is enough to check whether the last migration
	// the provider knows about has been applied.
	newest := p.migrations[len(p.migrations)-1]
	_, err = p.store.GetMigration(ctx, conn, newest.Version)
	switch {
	case err == nil:
		return false, nil
	case errors.Is(err, database.ErrVersionNotFound):
		return true, nil
	default:
		return false, err
	}
}
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
@ -478,7 +554,7 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64
if conn == nil {
var cleanup func() error
var err error
conn, cleanup, err = p.initialize(ctx)
conn, cleanup, err = p.initialize(ctx, true)
if err != nil {
return 0, err
}

View File

@ -14,6 +14,7 @@ import (
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/sqlparser"
"github.com/sethvargo/go-retry"
"go.uber.org/multierr"
)
@ -51,8 +52,14 @@ func (p *Provider) resolveUpMigrations(
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
var versionsMsg string
if len(collected) > 1 {
versionsMsg = "versions " + strings.Join(collected, ",")
} else {
versionsMsg = "version " + collected[0]
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): %s",
len(missingMigrations), msg, dbMaxVersion, versionsMsg,
)
}
for _, missingVersion := range missingMigrations {
@ -291,7 +298,7 @@ func beginTx(ctx context.Context, conn *sql.Conn, fn func(tx *sql.Tx) error) (re
return tx.Commit()
}
func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
func (p *Provider) initialize(ctx context.Context, useSessionLocker bool) (*sql.Conn, func() error, error) {
p.mu.Lock()
conn, err := p.db.Conn(ctx)
if err != nil {
@ -303,7 +310,8 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
p.mu.Unlock()
return conn.Close()
}
if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled {
if useSessionLocker && p.cfg.sessionLocker != nil && p.cfg.lockEnabled {
l := p.cfg.sessionLocker
if err := l.SessionLock(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
}
@ -320,7 +328,7 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}
}
// If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't
// need the version table because no versions are being recorded.
// need the version table because no versions are being tracked.
if !p.cfg.disableVersioning {
if err := p.ensureVersionTable(ctx, conn); err != nil {
return nil, nil, multierr.Append(err, cleanup())
@ -329,36 +337,61 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
return conn, cleanup, nil
}
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// existor is an interface that extends the Store interface with a method to check if the
// version table exists. This API is not stable and may change in the future.
type existor interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
}
if e, ok := p.store.(existor); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
// ensureVersionTable delegates to tryEnsureVersionTable, guarded by versionTableOnce so the work
// is attempted at most once per Provider instance.
//
// Two optimizations are combined here:
//  1. The version table is created once per Provider instance.
//  2. The operation is retried a few times in case the table is being created concurrently.
//
// The retry matters because some goose operations, such as HasPending, do not respect a
// SessionLocker. When goose runs for the first time in a multi-instance environment, several
// instances may try to create the version table simultaneously. In the best case one instance
// creates it and the others see it immediately; in the worst case all of them attempt the
// creation, one succeeds and the rest retry.
//
// NOTE(review): sync.Once runs its function only once, so an error from the first attempt is
// returned to that first caller only; subsequent calls return nil without re-checking the table.
func (p *Provider) ensureVersionTable(
	ctx context.Context,
	conn *sql.Conn,
) (retErr error) {
	attempt := func() {
		retErr = p.tryEnsureVersionTable(ctx, conn)
	}
	p.versionTableOnce.Do(attempt)
	return retErr
}
func (p *Provider) tryEnsureVersionTable(ctx context.Context, conn *sql.Conn) error {
b := retry.NewConstant(1 * time.Second)
b = retry.WithMaxRetries(3, b)
return retry.Do(ctx, b, func(ctx context.Context) error {
if e, ok := p.store.(interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
}); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}
if exists {
return nil
}
} else {
// This chicken-and-egg behavior is the fallback for all existing implementations of the
// Store interface. We check if the version table exists by querying for the initial
// version, but the table may not exist yet. It's important this runs outside of a
// transaction to avoid failing the transaction.
if res, err := p.store.GetMigration(ctx, conn, 0); err == nil && res != nil {
return nil
}
}
if exists {
return nil
if err := beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
}); err != nil {
// Mark the error as retryable so we can try again. It's possible that another instance
// is creating the table at the same time and the checks above will succeed on the next
// iteration.
return retry.RetryableError(fmt.Errorf("failed to create version table: %w", err))
}
} else {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
}
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
}
if p.cfg.disableVersioning {
return nil
}
return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0})
return nil
})
}

View File

@ -775,6 +775,53 @@ func TestProviderApply(t *testing.T) {
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
}
// TestHasPending checks Provider.HasPending on SQLite3 with out-of-order migrations both allowed
// and disallowed. Each scenario applies a couple of versions individually, expects HasPending to
// report true, runs Up, and then expects HasPending to report false.
func TestHasPending(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		name            string
		allowOutOfOrder bool
		// applyFirst lists the versions applied one by one before the first HasPending call.
		applyFirst []int64
	}{
		// Versions 1 and 3 are applied, leaving a gap at 2 (out of order).
		{name: "allow_out_of_order", allowOutOfOrder: true, applyFirst: []int64{1, 3}},
		// Versions 1 and 2 are applied in order.
		{name: "disallow_out_of_order", allowOutOfOrder: false, applyFirst: []int64{1, 2}},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			ctx := context.Background()
			provider, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(),
				goose.WithAllowOutofOrder(sc.allowOutOfOrder),
			)
			check.NoError(t, err)

			// Only some of the migrations are applied at this point.
			for _, version := range sc.applyFirst {
				_, err = provider.ApplyVersion(ctx, version, true)
				check.NoError(t, err)
			}
			pending, err := provider.HasPending(ctx)
			check.NoError(t, err)
			check.Bool(t, pending, true)

			// Apply whatever is left.
			_, err = provider.Up(ctx)
			check.NoError(t, err)

			// Everything has been applied now.
			pending, err = provider.HasPending(ctx)
			check.NoError(t, err)
			check.Bool(t, pending, false)
		})
	}
}
// customStoreSQLite3 embeds database.Store, so it exposes the embedded store's methods.
// NOTE(review): presumably used by the tests in this file to supply a custom Store for SQLite3;
// confirm against its usages.
type customStoreSQLite3 struct {
database.Store
}