minor improvement

mf/storecontroller
Mike Fridman 2024-04-20 23:11:26 -04:00
parent 7863f54bd5
commit b8700dd367
No known key found for this signature in database
6 changed files with 35 additions and 27 deletions

View File

@ -2,6 +2,7 @@ package database
import (
"context"
"database/sql"
"errors"
)
@ -33,11 +34,11 @@ func NewStoreController(store Store) *StoreController {
// TableExists is an optional method that checks if the version table exists in the database. It is
// recommended to implement this method if the database supports it, as it can be used to optimize
// certain operations.
func (c *StoreController) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
func (c *StoreController) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
if t, ok := c.store.(interface {
TableExists(context.Context, DBTxConn, string) (bool, error)
TableExists(ctx context.Context, db *sql.Conn) (bool, error)
}); ok {
return t.TableExists(ctx, db, c.Tablename())
return t.TableExists(ctx, db)
}
return false, ErrNotSupported
}

View File

@ -138,13 +138,13 @@ func (s *store) ListMigrations(
return migrations, nil
}
func (s *store) TableExists(ctx context.Context, db DBTxConn, name string) (bool, error) {
func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
q := s.querier.TableExists(s.tablename)
if q == "" {
return false, ErrNotSupported
}
var exists bool
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
if err := db.QueryRowContext(ctx, q, s.tablename).Scan(&exists); err != nil {
return false, fmt.Errorf("failed to check if table exists: %w", err)
}
return exists, nil

View File

@ -22,6 +22,8 @@ type Querier interface {
ListMigrations(tableName string) string
}
var _ Querier = (*QueryController)(nil)
type QueryController struct {
querier Querier
}

View File

@ -37,6 +37,6 @@ func (p *Postgres) ListMigrations(tableName string) string {
return fmt.Sprintf(q, tableName)
}
func (p *Postgres) TableExists(tableName string) string {
func (p *Postgres) TableExists(_ string) string {
return `SELECT EXISTS ( SELECT FROM pg_tables WHERE tablename = $1)`
}

View File

@ -330,15 +330,19 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
if ok, err := p.store.TableExists(ctx, conn); err != nil && !errors.Is(err, database.ErrNotSupported) {
ok, err := p.store.TableExists(ctx, conn)
if err != nil && !errors.Is(err, database.ErrNotSupported) {
return err
} else if ok {
}
if ok {
return nil
}
// Fall back to the default behavior if the Store does not implement TableExists.
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
if errors.Is(err, database.ErrNotSupported) {
// Fall back to the default behavior if the Store does not implement TableExists.
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
}
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {

View File

@ -748,19 +748,6 @@ func TestGoMigrationPanic(t *testing.T) {
check.Contains(t, expected.Err.Error(), wantErrString)
}
func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}
func TestProviderApply(t *testing.T) {
t.Parallel()
@ -774,15 +761,29 @@ func TestProviderApply(t *testing.T) {
check.HasError(t, err)
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
}
func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}
type customStoreSQLite3 struct {
database.Store
}
func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
func (s *customStoreSQLite3) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
var exists bool
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
if err := db.QueryRowContext(ctx, q, s.Tablename()).Scan(&exists); err != nil {
return false, err
}
return exists, nil