mirror of
https://github.com/pressly/goose.git
synced 2025-05-31 11:42:04 +00:00
refactor: add TableExists assertion support and improve docs (#641)
This commit is contained in:
parent
c5e0d3cffc
commit
9e6ef20c4f
@ -21,7 +21,8 @@ type Store interface {
|
||||
Tablename() string
|
||||
|
||||
// CreateVersionTable creates the version table. This table is used to record applied
|
||||
// migrations.
|
||||
// migrations. When creating the table, the implementation must also insert a row for the
|
||||
// initial version (0).
|
||||
CreateVersionTable(ctx context.Context, db DBTxConn) error
|
||||
|
||||
// Insert inserts a version id into the version table.
|
||||
@ -35,7 +36,9 @@ type Store interface {
|
||||
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
|
||||
|
||||
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
|
||||
// there are no migrations, return empty slice with no error.
|
||||
// there are no migrations, return empty slice with no error. Typically this method will return
|
||||
// at least one migration, because the initial version (0) is always inserted into the version
|
||||
// table when it is created.
|
||||
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)
|
||||
|
||||
// TODO(mf): remove this method once the Provider is public and a custom Store can be used.
|
||||
|
@ -89,12 +89,12 @@ func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Con
|
||||
return nil
|
||||
}
|
||||
/*
|
||||
TODO(mf): provide users with some documentation on how they can unlock the session
|
||||
docs(md): provide users with some documentation on how they can unlock the session
|
||||
manually.
|
||||
|
||||
This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will
|
||||
release all session level advisory locks held by the current session. This function is
|
||||
implicitly invoked at session end, even if the client disconnects ungracefully.
|
||||
release all session level advisory locks held by the current session. It is implicitly
|
||||
invoked at session end, even if the client disconnects ungracefully.
|
||||
|
||||
Here is output from a session that has a lock held:
|
||||
|
||||
|
@ -391,3 +391,8 @@ func truncateDuration(d time.Duration) time.Duration {
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// ref returns a string that identifies the migration. This is used for logging and error messages.
|
||||
func (m *Migration) ref() string {
|
||||
return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
|
||||
}
|
||||
|
10
provider.go
10
provider.go
@ -297,7 +297,7 @@ func (p *Provider) up(
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to initialize: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
@ -339,7 +339,7 @@ func (p *Provider) down(
|
||||
) (_ []*MigrationResult, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to initialize: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
@ -397,7 +397,7 @@ func (p *Provider) apply(
|
||||
}
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to initialize: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
@ -422,7 +422,7 @@ func (p *Provider) apply(
|
||||
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to initialize: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
@ -455,7 +455,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err
|
||||
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
|
||||
conn, cleanup, err := p.initialize(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("failed to initialize: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
retErr = multierr.Append(retErr, cleanup())
|
||||
|
@ -313,8 +313,8 @@ func TestCheckMissingMigrations(t *testing.T) {
|
||||
}
|
||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
check.Number(t, got[0].versionID, 2)
|
||||
check.Number(t, got[1].versionID, 6)
|
||||
check.Number(t, got[0], 2)
|
||||
check.Number(t, got[1], 6)
|
||||
|
||||
// Sanity check.
|
||||
check.Number(t, len(checkMissingMigrations(nil, nil)), 0)
|
||||
@ -333,8 +333,8 @@ func TestCheckMissingMigrations(t *testing.T) {
|
||||
}
|
||||
got := checkMissingMigrations(dbMigrations, fsMigrations)
|
||||
check.Number(t, len(got), 2)
|
||||
check.Number(t, got[0].versionID, 3)
|
||||
check.Number(t, got[1].versionID, 4)
|
||||
check.Number(t, got[0], 3)
|
||||
check.Number(t, got[1], 4)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,6 @@ package goose
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -32,9 +31,8 @@ type PartialError struct {
|
||||
}
|
||||
|
||||
func (e *PartialError) Error() string {
|
||||
filename := "(file unknown)"
|
||||
if e.Failed != nil && e.Failed.Source.Path != "" {
|
||||
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
|
||||
}
|
||||
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
|
||||
return fmt.Sprintf(
|
||||
"partial migration error (type:%s,version:%d): %v",
|
||||
e.Failed.Source.Type, e.Failed.Source.Version, e.Err,
|
||||
)
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -43,7 +44,7 @@ func (p *Provider) resolveUpMigrations(
|
||||
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
|
||||
var collected []string
|
||||
for _, v := range missingMigrations {
|
||||
collected = append(collected, fmt.Sprintf("%d", v.versionID))
|
||||
collected = append(collected, strconv.FormatInt(v, 10))
|
||||
}
|
||||
msg := "migration"
|
||||
if len(collected) > 1 {
|
||||
@ -53,8 +54,8 @@ func (p *Provider) resolveUpMigrations(
|
||||
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
|
||||
)
|
||||
}
|
||||
for _, v := range missingMigrations {
|
||||
m, err := p.getMigration(v.versionID)
|
||||
for _, missingVersion := range missingMigrations {
|
||||
m, err := p.getMigration(missingVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -141,7 +142,7 @@ func (p *Provider) runMigrations(
|
||||
|
||||
for _, m := range apply {
|
||||
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to prepare migration %s: %w", m.ref(), err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -301,11 +302,26 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
|
||||
}
|
||||
|
||||
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
||||
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
||||
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
// existor is an interface that extends the Store interface with a method to check if the
|
||||
// version table exists. This API is not stable and may change in the future.
|
||||
type existor interface {
|
||||
TableExists(context.Context, database.DBTxConn, string) (bool, error)
|
||||
}
|
||||
if e, ok := p.store.(existor); ok {
|
||||
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if version table exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
|
||||
// from a table that may not exist. https://github.com/pressly/goose/issues/461
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
||||
@ -318,16 +334,12 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
|
||||
})
|
||||
}
|
||||
|
||||
type missingMigration struct {
|
||||
versionID int64
|
||||
}
|
||||
|
||||
// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
|
||||
// migration is one that has a version less than the max version in the database.
|
||||
func checkMissingMigrations(
|
||||
dbMigrations []*database.ListMigrationsResult,
|
||||
fsMigrations []*Migration,
|
||||
) []missingMigration {
|
||||
) []int64 {
|
||||
existing := make(map[int64]bool)
|
||||
var dbMaxVersion int64
|
||||
for _, m := range dbMigrations {
|
||||
@ -336,17 +348,14 @@ func checkMissingMigrations(
|
||||
dbMaxVersion = m.Version
|
||||
}
|
||||
}
|
||||
var missing []missingMigration
|
||||
var missing []int64
|
||||
for _, m := range fsMigrations {
|
||||
version := m.Version
|
||||
if !existing[version] && version < dbMaxVersion {
|
||||
missing = append(missing, missingMigration{
|
||||
versionID: version,
|
||||
})
|
||||
if !existing[m.Version] && m.Version < dbMaxVersion {
|
||||
missing = append(missing, m.Version)
|
||||
}
|
||||
}
|
||||
sort.Slice(missing, func(i, j int) bool {
|
||||
return missing[i].versionID < missing[j].versionID
|
||||
return missing[i] < missing[j]
|
||||
})
|
||||
return missing
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"github.com/pressly/goose/v3/database"
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/testdb"
|
||||
"github.com/pressly/goose/v3/lock"
|
||||
@ -31,7 +32,7 @@ func TestProviderRun(t *testing.T) {
|
||||
check.NoError(t, db.Close())
|
||||
_, err := p.Up(context.Background())
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "sql: database is closed")
|
||||
check.Equal(t, err.Error(), "failed to initialize: sql: database is closed")
|
||||
})
|
||||
t.Run("ping_and_close", func(t *testing.T) {
|
||||
p, _ := newProviderWithDB(t)
|
||||
@ -324,7 +325,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(ctx)
|
||||
check.HasError(t, err)
|
||||
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
|
||||
check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)")
|
||||
var expected *goose.PartialError
|
||||
check.Bool(t, errors.As(err, &expected), true)
|
||||
// Check Err field
|
||||
@ -723,6 +724,32 @@ func TestSQLiteSharedCache(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCustomStoreTableExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
p, err := goose.NewProvider("", newDB(t), newFsys(),
|
||||
goose.WithStore(&customStoreSQLite3{store}),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(context.Background())
|
||||
check.NoError(t, err)
|
||||
}
|
||||
|
||||
type customStoreSQLite3 struct {
|
||||
database.Store
|
||||
}
|
||||
|
||||
func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
|
||||
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
|
||||
var exists bool
|
||||
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
|
||||
var gotVersion int64
|
||||
if err := db.QueryRow(
|
||||
|
1
up.go
1
up.go
@ -181,7 +181,6 @@ func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...Options
|
||||
}
|
||||
|
||||
// listAllDBVersions returns a list of all migrations, ordered ascending.
|
||||
// TODO(mf): fairly cheap, but a nice-to-have is pagination support.
|
||||
func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) {
|
||||
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
|
||||
if err != nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user