mirror of https://github.com/pressly/goose.git
minor improvement
parent
7863f54bd5
commit
b8700dd367
|
@ -2,6 +2,7 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
|
@ -33,11 +34,11 @@ func NewStoreController(store Store) *StoreController {
|
|||
// TableExists is an optional method that checks if the version table exists in the database. It is
|
||||
// recommended to implement this method if the database supports it, as it can be used to optimize
|
||||
// certain operations.
|
||||
func (c *StoreController) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
|
||||
func (c *StoreController) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
|
||||
if t, ok := c.store.(interface {
|
||||
TableExists(context.Context, DBTxConn, string) (bool, error)
|
||||
TableExists(ctx context.Context, db *sql.Conn) (bool, error)
|
||||
}); ok {
|
||||
return t.TableExists(ctx, db, c.Tablename())
|
||||
return t.TableExists(ctx, db)
|
||||
}
|
||||
return false, ErrNotSupported
|
||||
}
|
||||
|
|
|
@ -138,13 +138,13 @@ func (s *store) ListMigrations(
|
|||
return migrations, nil
|
||||
}
|
||||
|
||||
func (s *store) TableExists(ctx context.Context, db DBTxConn, name string) (bool, error) {
|
||||
func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
|
||||
q := s.querier.TableExists(s.tablename)
|
||||
if q == "" {
|
||||
return false, ErrNotSupported
|
||||
}
|
||||
var exists bool
|
||||
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
|
||||
if err := db.QueryRowContext(ctx, q, s.tablename).Scan(&exists); err != nil {
|
||||
return false, fmt.Errorf("failed to check if table exists: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
|
|
|
@ -22,6 +22,8 @@ type Querier interface {
|
|||
ListMigrations(tableName string) string
|
||||
}
|
||||
|
||||
var _ Querier = (*QueryController)(nil)
|
||||
|
||||
type QueryController struct {
|
||||
querier Querier
|
||||
}
|
||||
|
|
|
@ -37,6 +37,6 @@ func (p *Postgres) ListMigrations(tableName string) string {
|
|||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (p *Postgres) TableExists(tableName string) string {
|
||||
func (p *Postgres) TableExists(_ string) string {
|
||||
return `SELECT EXISTS ( SELECT FROM pg_tables WHERE tablename = $1)`
|
||||
}
|
||||
|
|
|
@ -330,15 +330,19 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
|
|||
}
|
||||
|
||||
func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
|
||||
if ok, err := p.store.TableExists(ctx, conn); err != nil && !errors.Is(err, database.ErrNotSupported) {
|
||||
ok, err := p.store.TableExists(ctx, conn)
|
||||
if err != nil && !errors.Is(err, database.ErrNotSupported) {
|
||||
return err
|
||||
} else if ok {
|
||||
}
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
// Fall back to the default behavior if the Store does not implement TableExists.
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
if errors.Is(err, database.ErrNotSupported) {
|
||||
// Fall back to the default behavior if the Store does not implement TableExists.
|
||||
res, err := p.store.GetMigration(ctx, conn, 0)
|
||||
if err == nil && res != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return beginTx(ctx, conn, func(tx *sql.Tx) error {
|
||||
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
|
||||
|
|
|
@ -748,19 +748,6 @@ func TestGoMigrationPanic(t *testing.T) {
|
|||
check.Contains(t, expected.Err.Error(), wantErrString)
|
||||
}
|
||||
|
||||
func TestCustomStoreTableExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
p, err := goose.NewProvider("", newDB(t), newFsys(),
|
||||
goose.WithStore(&customStoreSQLite3{store}),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(context.Background())
|
||||
check.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProviderApply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -774,15 +761,29 @@ func TestProviderApply(t *testing.T) {
|
|||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
|
||||
}
|
||||
func TestCustomStoreTableExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
|
||||
check.NoError(t, err)
|
||||
p, err := goose.NewProvider("", newDB(t), newFsys(),
|
||||
goose.WithStore(&customStoreSQLite3{store}),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(context.Background())
|
||||
check.NoError(t, err)
|
||||
_, err = p.Up(context.Background())
|
||||
check.NoError(t, err)
|
||||
}
|
||||
|
||||
type customStoreSQLite3 struct {
|
||||
database.Store
|
||||
}
|
||||
|
||||
func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
|
||||
func (s *customStoreSQLite3) TableExists(ctx context.Context, db *sql.Conn) (bool, error) {
|
||||
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
|
||||
var exists bool
|
||||
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
|
||||
if err := db.QueryRowContext(ctx, q, s.Tablename()).Scan(&exists); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
|
|
Loading…
Reference in New Issue