refactor: add TableExists assertion support and improve docs (#641)

This commit is contained in:
Michael Fridman 2023-11-12 11:01:28 -05:00 committed by GitHub
parent c5e0d3cffc
commit 9e6ef20c4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 44 deletions

View File

@ -21,7 +21,8 @@ type Store interface {
Tablename() string
// CreateVersionTable creates the version table. This table is used to record applied
// migrations.
// migrations. When creating the table, the implementation must also insert a row for the
// initial version (0).
CreateVersionTable(ctx context.Context, db DBTxConn) error
// Insert inserts a version id into the version table.
@ -35,7 +36,9 @@ type Store interface {
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error.
// there are no migrations, return empty slice with no error. Typically this method will return
// at least one migration, because the initial version (0) is always inserted into the version
// table when it is created.
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)
// TODO(mf): remove this method once the Provider is public and a custom Store can be used.

View File

@ -89,12 +89,12 @@ func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Con
return nil
}
/*
TODO(mf): provide users with some documentation on how they can unlock the session
docs(md): provide users with some documentation on how they can unlock the session
manually.
This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will
release all session level advisory locks held by the current session. This function is
implicitly invoked at session end, even if the client disconnects ungracefully.
release all session level advisory locks held by the current session. It is implicitly
invoked at session end, even if the client disconnects ungracefully.
Here is output from a session that has a lock held:

View File

@ -391,3 +391,8 @@ func truncateDuration(d time.Duration) time.Duration {
}
return d
}
// ref returns a string that identifies the migration. This is used for logging and error messages.
func (m *Migration) ref() string {
return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
}

View File

@ -297,7 +297,7 @@ func (p *Provider) up(
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
@ -339,7 +339,7 @@ func (p *Provider) down(
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
@ -397,7 +397,7 @@ func (p *Provider) apply(
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
@ -422,7 +422,7 @@ func (p *Provider) apply(
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
@ -455,7 +455,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())

View File

@ -313,8 +313,8 @@ func TestCheckMissingMigrations(t *testing.T) {
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 2)
check.Number(t, got[1].versionID, 6)
check.Number(t, got[0], 2)
check.Number(t, got[1], 6)
// Sanity check.
check.Number(t, len(checkMissingMigrations(nil, nil)), 0)
@ -333,8 +333,8 @@ func TestCheckMissingMigrations(t *testing.T) {
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 3)
check.Number(t, got[1].versionID, 4)
check.Number(t, got[0], 3)
check.Number(t, got[1], 4)
})
}

View File

@ -3,7 +3,6 @@ package goose
import (
"errors"
"fmt"
"path/filepath"
)
var (
@ -32,9 +31,8 @@ type PartialError struct {
}
func (e *PartialError) Error() string {
filename := "(file unknown)"
if e.Failed != nil && e.Failed.Source.Path != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
}
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
return fmt.Sprintf(
"partial migration error (type:%s,version:%d): %v",
e.Failed.Source.Type, e.Failed.Source.Version, e.Err,
)
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"time"
@ -43,7 +44,7 @@ func (p *Provider) resolveUpMigrations(
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, fmt.Sprintf("%d", v.versionID))
collected = append(collected, strconv.FormatInt(v, 10))
}
msg := "migration"
if len(collected) > 1 {
@ -53,8 +54,8 @@ func (p *Provider) resolveUpMigrations(
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
for _, missingVersion := range missingMigrations {
m, err := p.getMigration(missingVersion)
if err != nil {
return nil, err
}
@ -141,7 +142,7 @@ func (p *Provider) runMigrations(
for _, m := range apply {
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
return nil, err
return nil, fmt.Errorf("failed to prepare migration %s: %w", m.ref(), err)
}
}
@ -301,11 +302,26 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
// existor is an interface that extends the Store interface with a method to check if the
// version table exists. This API is not stable and may change in the future.
type existor interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
}
if e, ok := p.store.(existor); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}
if exists {
return nil
}
} else {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
}
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
@ -318,16 +334,12 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
})
}
type missingMigration struct {
versionID int64
}
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
// migration is one that has a version less than the max version in the database.
func checkMissingMigrations(
dbMigrations []*database.ListMigrationsResult,
fsMigrations []*Migration,
) []missingMigration {
) []int64 {
existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations {
@ -336,17 +348,14 @@ func checkMissingMigrations(
dbMaxVersion = m.Version
}
}
var missing []missingMigration
var missing []int64
for _, m := range fsMigrations {
version := m.Version
if !existing[version] && version < dbMaxVersion {
missing = append(missing, missingMigration{
versionID: version,
})
if !existing[m.Version] && m.Version < dbMaxVersion {
missing = append(missing, m.Version)
}
}
sort.Slice(missing, func(i, j int) bool {
return missing[i].versionID < missing[j].versionID
return missing[i] < missing[j]
})
return missing
}

View File

@ -17,6 +17,7 @@ import (
"testing/fstest"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
@ -31,7 +32,7 @@ func TestProviderRun(t *testing.T) {
check.NoError(t, db.Close())
_, err := p.Up(context.Background())
check.HasError(t, err)
check.Equal(t, err.Error(), "sql: database is closed")
check.Equal(t, err.Error(), "failed to initialize: sql: database is closed")
})
t.Run("ping_and_close", func(t *testing.T) {
p, _ := newProviderWithDB(t)
@ -324,7 +325,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
check.NoError(t, err)
_, err = p.Up(ctx)
check.HasError(t, err)
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)")
var expected *goose.PartialError
check.Bool(t, errors.As(err, &expected), true)
// Check Err field
@ -723,6 +724,32 @@ func TestSQLiteSharedCache(t *testing.T) {
})
}
func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}
type customStoreSQLite3 struct {
database.Store
}
func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
var exists bool
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
return false, err
}
return exists, nil
}
func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
var gotVersion int64
if err := db.QueryRow(

1
up.go
View File

@ -181,7 +181,6 @@ func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...Options
}
// listAllDBVersions returns a list of all migrations, ordered ascending.
// TODO(mf): fairly cheap, but a nice-to-have is pagination support.
func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil {