diff --git a/database/store.go b/database/store.go index 107a1ee..86ca03b 100644 --- a/database/store.go +++ b/database/store.go @@ -21,7 +21,8 @@ type Store interface { Tablename() string // CreateVersionTable creates the version table. This table is used to record applied - // migrations. + // migrations. When creating the table, the implementation must also insert a row for the + // initial version (0). CreateVersionTable(ctx context.Context, db DBTxConn) error // Insert inserts a version id into the version table. @@ -35,7 +36,9 @@ type Store interface { GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If - // there are no migrations, return empty slice with no error. + // there are no migrations, return empty slice with no error. Typically this method will return + // at least one migration, because the initial version (0) is always inserted into the version + // table when it is created. ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error) // TODO(mf): remove this method once the Provider is public and a custom Store can be used. diff --git a/lock/postgres.go b/lock/postgres.go index 25405b2..97b7bae 100644 --- a/lock/postgres.go +++ b/lock/postgres.go @@ -89,12 +89,12 @@ func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Con return nil } /* - TODO(mf): provide users with some documentation on how they can unlock the session + docs(md): provide users with some documentation on how they can unlock the session manually. This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will - release all session level advisory locks held by the current session. This function is - implicitly invoked at session end, even if the client disconnects ungracefully. + release all session level advisory locks held by the current session. It is implicitly + invoked at session end, even if the client disconnects ungracefully. Here is output from a session that has a lock held: diff --git a/migration.go b/migration.go index 378e626..d1fbd7d 100644 --- a/migration.go +++ b/migration.go @@ -391,3 +391,8 @@ func truncateDuration(d time.Duration) time.Duration { } return d } + +// ref returns a string that identifies the migration. This is used for logging and error messages. +func (m *Migration) ref() string { + return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version) +} diff --git a/provider.go b/provider.go index 4be894d..41b22e2 100644 --- a/provider.go +++ b/provider.go @@ -297,7 +297,7 @@ func (p *Provider) up( } conn, cleanup, err := p.initialize(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) @@ -339,7 +339,7 @@ func (p *Provider) down( ) (_ []*MigrationResult, retErr error) { conn, cleanup, err := p.initialize(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) @@ -397,7 +397,7 @@ func (p *Provider) apply( } conn, cleanup, err := p.initialize(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) @@ -422,7 +422,7 @@ func (p *Provider) apply( func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { conn, cleanup, err := p.initialize(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) @@ -455,7 +455,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) { conn, cleanup, err := p.initialize(ctx) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to initialize: %w", err) } defer func() { retErr = multierr.Append(retErr, cleanup()) diff --git a/provider_collect_test.go b/provider_collect_test.go index d07004a..f537227 100644 --- a/provider_collect_test.go +++ b/provider_collect_test.go @@ -313,8 +313,8 @@ func TestCheckMissingMigrations(t *testing.T) { } got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) - check.Number(t, got[0].versionID, 2) - check.Number(t, got[1].versionID, 6) + check.Number(t, got[0], 2) + check.Number(t, got[1], 6) // Sanity check. check.Number(t, len(checkMissingMigrations(nil, nil)), 0) @@ -333,8 +333,8 @@ func TestCheckMissingMigrations(t *testing.T) { } got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) - check.Number(t, got[0].versionID, 3) - check.Number(t, got[1].versionID, 4) + check.Number(t, got[0], 3) + check.Number(t, got[1], 4) }) } diff --git a/provider_errors.go b/provider_errors.go index 3004641..72c8d67 100644 --- a/provider_errors.go +++ b/provider_errors.go @@ -3,7 +3,6 @@ package goose import ( "errors" "fmt" - "path/filepath" ) var ( @@ -32,9 +31,8 @@ type PartialError struct { } func (e *PartialError) Error() string { - filename := "(file unknown)" - if e.Failed != nil && e.Failed.Source.Path != "" { - filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path)) - } - return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) + return fmt.Sprintf( + "partial migration error (type:%s,version:%d): %v", + e.Failed.Source.Type, e.Failed.Source.Version, e.Err, + ) } diff --git a/provider_run.go b/provider_run.go index 28fe531..e69f94d 100644 --- a/provider_run.go +++ b/provider_run.go @@ -7,6 +7,7 @@ import ( "fmt" "io/fs" "sort" + "strconv" "strings" "time" @@ -43,7 +44,7 @@ func (p *Provider) resolveUpMigrations( if len(missingMigrations) > 0 && !p.cfg.allowMissing { var collected []string for _, v := range missingMigrations { - collected = append(collected, fmt.Sprintf("%d", v.versionID)) + collected = append(collected, strconv.FormatInt(v, 10)) } msg := "migration" if len(collected) > 1 { @@ -53,8 +54,8 @@ func (p *Provider) resolveUpMigrations( len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","), ) } - for _, v := range missingMigrations { - m, err := p.getMigration(v.versionID) + for _, missingVersion := range missingMigrations { + m, err := p.getMigration(missingVersion) if err != nil { return nil, err } @@ -141,7 +142,7 @@ func (p *Provider) runMigrations( for _, m := range apply { if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil { - return nil, err + return nil, fmt.Errorf("failed to prepare migration %s: %w", m.ref(), err) } } @@ -301,11 +302,26 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err } func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { - // feat(mf): this is where we can check if the version table exists instead of trying to fetch - // from a table that may not exist. https://github.com/pressly/goose/issues/461 - res, err := p.store.GetMigration(ctx, conn, 0) - if err == nil && res != nil { - return nil + // existor is an interface that extends the Store interface with a method to check if the + // version table exists. This API is not stable and may change in the future. + type existor interface { + TableExists(context.Context, database.DBTxConn, string) (bool, error) + } + if e, ok := p.store.(existor); ok { + exists, err := e.TableExists(ctx, conn, p.store.Tablename()) + if err != nil { + return fmt.Errorf("failed to check if version table exists: %w", err) + } + if exists { + return nil + } + } else { + // feat(mf): this is where we can check if the version table exists instead of trying to fetch + // from a table that may not exist. https://github.com/pressly/goose/issues/461 + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } } return beginTx(ctx, conn, func(tx *sql.Tx) error { if err := p.store.CreateVersionTable(ctx, tx); err != nil { @@ -318,16 +334,12 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE }) } -type missingMigration struct { - versionID int64 -} - // checkMissingMigrations returns a list of migrations that are missing from the database. A missing // migration is one that has a version less than the max version in the database. func checkMissingMigrations( dbMigrations []*database.ListMigrationsResult, fsMigrations []*Migration, -) []missingMigration { +) []int64 { existing := make(map[int64]bool) var dbMaxVersion int64 for _, m := range dbMigrations { @@ -336,17 +348,14 @@ func checkMissingMigrations( dbMaxVersion = m.Version } } - var missing []missingMigration + var missing []int64 for _, m := range fsMigrations { - version := m.Version - if !existing[version] && version < dbMaxVersion { - missing = append(missing, missingMigration{ - versionID: version, - }) + if !existing[m.Version] && m.Version < dbMaxVersion { + missing = append(missing, m.Version) } } sort.Slice(missing, func(i, j int) bool { - return missing[i].versionID < missing[j].versionID + return missing[i] < missing[j] }) return missing } diff --git a/provider_run_test.go b/provider_run_test.go index 9b353c5..f30e09e 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -17,6 +17,7 @@ import ( "testing/fstest" "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/testdb" "github.com/pressly/goose/v3/lock" @@ -31,7 +32,7 @@ func TestProviderRun(t *testing.T) { check.NoError(t, db.Close()) _, err := p.Up(context.Background()) check.HasError(t, err) - check.Equal(t, err.Error(), "sql: database is closed") + check.Equal(t, err.Error(), "failed to initialize: sql: database is closed") }) t.Run("ping_and_close", func(t *testing.T) { p, _ := newProviderWithDB(t) @@ -324,7 +325,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); check.NoError(t, err) _, err = p.Up(ctx) check.HasError(t, err) - check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)") + check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)") var expected *goose.PartialError check.Bool(t, errors.As(err, &expected), true) // Check Err field @@ -723,6 +724,32 @@ func TestSQLiteSharedCache(t *testing.T) { }) } +func TestCustomStoreTableExists(t *testing.T) { + t.Parallel() + + store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename) + check.NoError(t, err) + p, err := goose.NewProvider("", newDB(t), newFsys(), + goose.WithStore(&customStoreSQLite3{store}), + ) + check.NoError(t, err) + _, err = p.Up(context.Background()) + check.NoError(t, err) +} + +type customStoreSQLite3 struct { + database.Store +} + +func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) { + q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists` + var exists bool + if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil { + return false, err + } + return exists, nil +} + func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { var gotVersion int64 if err := db.QueryRow( diff --git a/up.go b/up.go index ecee35f..4f75620 100644 --- a/up.go +++ b/up.go @@ -181,7 +181,6 @@ func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...Options } // listAllDBVersions returns a list of all migrations, ordered ascending. -// TODO(mf): fairly cheap, but a nice-to-have is pagination support. func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) { dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil {