Implement GetLatestVersion for all natively supported dialects (#758)

pull/759/head
Michael Fridman 2024-04-26 14:06:20 -04:00 committed by GitHub
parent 2d33f01788
commit 272603b047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 135 additions and 57 deletions

View File

@ -9,6 +9,14 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Add `CheckPending` method to `goose.Provider` to check if there are pending migrations, returns - Add `CheckPending` method to `goose.Provider` to check if there are pending migrations, returns
the current (max db) version and the latest (max file) version. (#756) the current (max db) version and the latest (max file) version. (#756)
- Clarify `GetLatestVersion` method MUST return `ErrVersionNotFound` if no latest migration is
found. Previously it was returning a -1 and nil error, which was inconsistent with the rest of the
API surface.
- Add `GetLatestVersion` implementations to all existing dialects. This is an optimization to avoid
loading all migrations when only the latest version is needed. This uses the `max` function in SQL
to get the latest version_id irrespective of the order of applied migrations.
- Refactor existing portions of the code to use the new `GetLatestVersion` method.
## [v3.20.0] ## [v3.20.0]

View File

@ -110,7 +110,15 @@ func (s *store) GetMigration(
} }
func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) { func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
return -1, errors.New("not implemented") q := s.querier.GetLatestVersion(s.tablename)
var version sql.NullInt64
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
return -1, fmt.Errorf("failed to get latest version: %w", err)
}
if !version.Valid {
return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
}
return version.Int64, nil
} }
func (s *store) ListMigrations( func (s *store) ListMigrations(

View File

@ -7,8 +7,12 @@ import (
) )
var ( var (
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found. // ErrVersionNotFound must be returned by [GetMigration] or [GetLatestVersion] when a migration
// does not exist.
ErrVersionNotFound = errors.New("version not found") ErrVersionNotFound = errors.New("version not found")
// ErrNotImplemented must be returned by methods that are not implemented.
ErrNotImplemented = errors.New("not implemented")
) )
// Store is an interface that defines methods for tracking and managing migrations. It is used by // Store is an interface that defines methods for tracking and managing migrations. It is used by
@ -34,7 +38,7 @@ type Store interface {
// version is not found, this method must return [ErrVersionNotFound]. // version is not found, this method must return [ErrVersionNotFound].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
// GetLatestVersion retrieves the last applied migration version. If no migrations exist, this // GetLatestVersion retrieves the last applied migration version. If no migrations exist, this
// method must return -1 and no error. // method must return [ErrVersionNotFound].
GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error)
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error. Typically this method will return // there are no migrations, return empty slice with no error. Typically this method will return

View File

@ -95,6 +95,9 @@ func testStore(
if alreadyExists != nil { if alreadyExists != nil {
alreadyExists(t, err) alreadyExists(t, err)
} }
// Get the latest version. There should be none.
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)
// List migrations. There should be none. // List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error { err = runConn(ctx, db, func(conn *sql.Conn) error {
@ -108,7 +111,12 @@ func testStore(
// Insert 5 migrations in addition to the zero migration. // Insert 5 migrations in addition to the zero migration.
for i := 0; i < 6; i++ { for i := 0; i < 6; i++ {
err = runConn(ctx, db, func(conn *sql.Conn) error { err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)}) err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i))
return nil
}) })
check.NoError(t, err) check.NoError(t, err)
} }
@ -129,7 +137,12 @@ func testStore(
// Delete 3 migrations backwards // Delete 3 migrations backwards
for i := 5; i >= 3; i-- { for i := 5; i >= 3; i-- {
err = runConn(ctx, db, func(conn *sql.Conn) error { err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Delete(ctx, conn, int64(i)) err := store.Delete(ctx, conn, int64(i))
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, int64(i-1))
return nil
}) })
check.NoError(t, err) check.NoError(t, err)
} }
@ -163,17 +176,29 @@ func testStore(
// 1. *sql.Tx // 1. *sql.Tx
err = runTx(ctx, db, func(tx *sql.Tx) error { err = runTx(ctx, db, func(tx *sql.Tx) error {
return store.Delete(ctx, tx, 2) err := store.Delete(ctx, tx, 2)
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, tx)
check.NoError(t, err)
check.Number(t, latest, 1)
return nil
}) })
check.NoError(t, err) check.NoError(t, err)
// 2. *sql.Conn // 2. *sql.Conn
err = runConn(ctx, db, func(conn *sql.Conn) error { err = runConn(ctx, db, func(conn *sql.Conn) error {
return store.Delete(ctx, conn, 1) err := store.Delete(ctx, conn, 1)
check.NoError(t, err)
latest, err := store.GetLatestVersion(ctx, conn)
check.NoError(t, err)
check.Number(t, latest, 0)
return nil
}) })
check.NoError(t, err) check.NoError(t, err)
// 3. *sql.DB // 3. *sql.DB
err = store.Delete(ctx, db, 0) err = store.Delete(ctx, db, 0)
check.NoError(t, err) check.NoError(t, err)
_, err = store.GetLatestVersion(ctx, db)
check.IsError(t, err, database.ErrVersionNotFound)
// List migrations. There should be none. // List migrations. There should be none.
err = runConn(ctx, db, func(conn *sql.Conn) error { err = runConn(ctx, db, func(conn *sql.Conn) error {

View File

@ -37,3 +37,8 @@ func (c *Clickhouse) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC` q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (c *Clickhouse) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -1,28 +1,27 @@
package dialectquery package dialectquery
// Querier is the interface that wraps the basic methods to create a dialect // Querier is the interface that wraps the basic methods to create a dialect specific query.
// specific query.
type Querier interface { type Querier interface {
// CreateTable returns the SQL query string to create the db version table. // CreateTable returns the SQL query string to create the db version table.
CreateTable(tableName string) string CreateTable(tableName string) string
// InsertVersion returns the SQL query string to insert a new version into // InsertVersion returns the SQL query string to insert a new version into the db version table.
// the db version table.
InsertVersion(tableName string) string InsertVersion(tableName string) string
// DeleteVersion returns the SQL query string to delete a version from // DeleteVersion returns the SQL query string to delete a version from the db version table.
// the db version table.
DeleteVersion(tableName string) string DeleteVersion(tableName string) string
// GetMigrationByVersion returns the SQL query string to get a single // GetMigrationByVersion returns the SQL query string to get a single migration by version.
// migration by version.
// //
// The query should return the timestamp and is_applied columns. // The query should return the timestamp and is_applied columns.
GetMigrationByVersion(tableName string) string GetMigrationByVersion(tableName string) string
// ListMigrations returns the SQL query string to list all migrations in // ListMigrations returns the SQL query string to list all migrations in descending order by id.
// descending order by id.
// //
// The query should return the version_id and is_applied columns. // The query should return the version_id and is_applied columns.
ListMigrations(tableName string) string ListMigrations(tableName string) string
// GetLatestVersion returns the SQL query string to get the last version_id from the db version
// table. Returns a nullable int64 value.
GetLatestVersion(tableName string) string
} }

View File

@ -36,3 +36,8 @@ func (m *Mysql) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (m *Mysql) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -36,3 +36,8 @@ func (p *Postgres) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (p *Postgres) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -36,3 +36,8 @@ func (r *Redshift) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (r *Redshift) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -35,3 +35,8 @@ func (s *Sqlite3) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlite3) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -35,3 +35,8 @@ func (s *Sqlserver) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC` q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlserver) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -36,3 +36,8 @@ func (t *Tidb) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (t *Tidb) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -36,3 +36,8 @@ func (v *Vertica) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (v *Vertica) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -46,3 +46,8 @@ func (c *Ydb) ListMigrations(tableName string) string {
FROM %s ORDER BY __discard_column_tstamp DESC` FROM %s ORDER BY __discard_column_tstamp DESC`
return fmt.Sprintf(q, tableName) return fmt.Sprintf(q, tableName)
} }
func (c *Ydb) GetLatestVersion(tableName string) string {
q := `SELECT MAX(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"math" "math"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -167,6 +166,10 @@ func (p *Provider) HasPending(ctx context.Context) (bool, error) {
// //
// Note, this method will not use a SessionLocker if one is configured. This allows callers to check // Note, this method will not use a SessionLocker if one is configured. This allows callers to check
// for pending migrations without blocking or being blocked by other operations. // for pending migrations without blocking or being blocked by other operations.
//
// If out-of-order migrations are enabled this method is not suitable for checking pending
// migrations because it ONLY returns the highest version in the database. Instead, use the
// [HasPending] method.
func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) { func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) {
return p.checkPending(ctx) return p.checkPending(ctx)
} }
@ -483,30 +486,22 @@ func (p *Provider) checkPending(ctx context.Context) (current, target int64, ret
retErr = multierr.Append(retErr, cleanup()) retErr = multierr.Append(retErr, cleanup())
}() }()
target = p.migrations[len(p.migrations)-1].Version
// If versioning is disabled, we always have pending migrations and the target version is the // If versioning is disabled, we always have pending migrations and the target version is the
// last migration. // last migration.
if p.cfg.disableVersioning { if p.cfg.disableVersioning {
return -1, p.migrations[len(p.migrations)-1].Version, nil return -1, target, nil
} }
// optimize(mf): we should only fetch the max version from the database, no need to fetch all
// migrations only to get the max version when we're not using out-of-order migrations. current, err = p.store.GetLatestVersion(ctx, conn)
res, err := p.store.ListMigrations(ctx, conn)
if err != nil { if err != nil {
return -1, -1, err if errors.Is(err, database.ErrVersionNotFound) {
return -1, target, errMissingZeroVersion
}
return -1, target, err
} }
dbVersions := make([]int64, 0, len(res)) return current, target, nil
for _, m := range res {
dbVersions = append(dbVersions, m.Version)
}
sort.Slice(dbVersions, func(i, j int) bool {
return dbVersions[i] < dbVersions[j]
})
if len(dbVersions) == 0 {
return -1, -1, errMissingZeroVersion
} else {
current = dbVersions[len(dbVersions)-1]
}
return current, p.migrations[len(p.migrations)-1].Version, nil
} }
func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
@ -523,7 +518,8 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
return true, nil return true, nil
} }
if p.cfg.allowMissing { if p.cfg.allowMissing {
// List all migrations from the database. // List all migrations from the database. We cannot optimize this because we need to check
// that EVERY migration known the provider has been applied.
dbMigrations, err := p.store.ListMigrations(ctx, conn) dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil { if err != nil {
return false, err return false, err
@ -544,16 +540,16 @@ func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
} }
return false, nil return false, nil
} }
// If out-of-order migrations are not allowed, we can optimize this by only checking whether the // If out-of-order migrations are not allowed, we can optimize this by only checking the latest
// last migration the provider knows about is applied. // version in the database against the latest migration version.
last := p.migrations[len(p.migrations)-1] current, err := p.store.GetLatestVersion(ctx, conn)
if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil { if err != nil {
if errors.Is(err, database.ErrVersionNotFound) { if errors.Is(err, database.ErrVersionNotFound) {
return true, nil return false, errMissingZeroVersion
} }
return false, err return false, err
} }
return false, nil return current < p.migrations[len(p.migrations)-1].Version, nil
} }
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
@ -591,9 +587,6 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err
// getDBMaxVersion returns the highest version recorded in the database, regardless of the order in // getDBMaxVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. conn may be nil, in which case a connection is initialized. // which migrations were applied. conn may be nil, in which case a connection is initialized.
//
// optimize(mf): we should only fetch the max version from the database, no need to fetch all
// migrations only to get the max version. This means expanding the Store interface.
func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) { func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64, retErr error) {
if conn == nil { if conn == nil {
var cleanup func() error var cleanup func() error
@ -606,17 +599,13 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64
retErr = multierr.Append(retErr, cleanup()) retErr = multierr.Append(retErr, cleanup())
}() }()
} }
res, err := p.store.ListMigrations(ctx, conn)
latest, err := p.store.GetLatestVersion(ctx, conn)
if err != nil { if err != nil {
return 0, err if errors.Is(err, database.ErrVersionNotFound) {
return 0, errMissingZeroVersion
}
return -1, err
} }
if len(res) == 0 { return latest, nil
return 0, errMissingZeroVersion
}
// Sort in descending order.
sort.Slice(res, func(i, j int) bool {
return res[i].Version > res[j].Version
})
// Return the highest version.
return res[0].Version, nil
} }