mirror of https://github.com/pressly/goose.git
173 lines
4.8 KiB
Go
173 lines
4.8 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/pressly/goose/v3/internal/dialect/dialectquery"
|
|
)
|
|
|
|
// Dialect identifies a database dialect by name. The underlying string is the
// identifier callers pass to [NewStore] to pick a dialect-specific querier.
type Dialect string

// Dialects recognized by [NewStore], listed alphabetically.
const (
	DialectClickHouse Dialect = "clickhouse"
	DialectMSSQL      Dialect = "mssql"
	DialectMySQL      Dialect = "mysql"
	DialectPostgres   Dialect = "postgres"
	DialectRedshift   Dialect = "redshift"
	DialectSQLite3    Dialect = "sqlite3"
	DialectStarrocks  Dialect = "starrocks"
	DialectTiDB       Dialect = "tidb"
	DialectTurso      Dialect = "turso"
	DialectVertica    Dialect = "vertica"
	DialectYdB        Dialect = "ydb"
)
|
|
|
|
// NewStore returns a new [Store] implementation for the given dialect.
|
|
func NewStore(dialect Dialect, tablename string) (Store, error) {
|
|
if tablename == "" {
|
|
return nil, errors.New("table name must not be empty")
|
|
}
|
|
if dialect == "" {
|
|
return nil, errors.New("dialect must not be empty")
|
|
}
|
|
lookup := map[Dialect]dialectquery.Querier{
|
|
DialectClickHouse: &dialectquery.Clickhouse{},
|
|
DialectMSSQL: &dialectquery.Sqlserver{},
|
|
DialectMySQL: &dialectquery.Mysql{},
|
|
DialectPostgres: &dialectquery.Postgres{},
|
|
DialectRedshift: &dialectquery.Redshift{},
|
|
DialectSQLite3: &dialectquery.Sqlite3{},
|
|
DialectTiDB: &dialectquery.Tidb{},
|
|
DialectVertica: &dialectquery.Vertica{},
|
|
DialectYdB: &dialectquery.Ydb{},
|
|
DialectTurso: &dialectquery.Turso{},
|
|
DialectStarrocks: &dialectquery.Starrocks{},
|
|
}
|
|
querier, ok := lookup[dialect]
|
|
if !ok {
|
|
return nil, fmt.Errorf("unknown dialect: %q", dialect)
|
|
}
|
|
return &store{
|
|
tablename: tablename,
|
|
querier: dialectquery.NewQueryController(querier),
|
|
}, nil
|
|
}
|
|
|
|
type store struct {
|
|
tablename string
|
|
querier *dialectquery.QueryController
|
|
}
|
|
|
|
var _ Store = (*store)(nil)
|
|
|
|
func (s *store) Tablename() string {
|
|
return s.tablename
|
|
}
|
|
|
|
func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
|
|
q := s.querier.CreateTable(s.tablename)
|
|
if _, err := db.ExecContext(ctx, q); err != nil {
|
|
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
|
|
q := s.querier.InsertVersion(s.tablename)
|
|
if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
|
|
return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
|
|
q := s.querier.DeleteVersion(s.tablename)
|
|
if _, err := db.ExecContext(ctx, q, version); err != nil {
|
|
return fmt.Errorf("failed to delete version %d: %w", version, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *store) GetMigration(
|
|
ctx context.Context,
|
|
db DBTxConn,
|
|
version int64,
|
|
) (*GetMigrationResult, error) {
|
|
q := s.querier.GetMigrationByVersion(s.tablename)
|
|
var result GetMigrationResult
|
|
if err := db.QueryRowContext(ctx, q, version).Scan(
|
|
&result.Timestamp,
|
|
&result.IsApplied,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
|
|
}
|
|
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
|
|
q := s.querier.GetLatestVersion(s.tablename)
|
|
var version sql.NullInt64
|
|
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
|
|
return -1, fmt.Errorf("failed to get latest version: %w", err)
|
|
}
|
|
if !version.Valid {
|
|
return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
|
|
}
|
|
return version.Int64, nil
|
|
}
|
|
|
|
func (s *store) ListMigrations(
|
|
ctx context.Context,
|
|
db DBTxConn,
|
|
) ([]*ListMigrationsResult, error) {
|
|
q := s.querier.ListMigrations(s.tablename)
|
|
rows, err := db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list migrations: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var migrations []*ListMigrationsResult
|
|
for rows.Next() {
|
|
var result ListMigrationsResult
|
|
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
|
|
return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
|
|
}
|
|
migrations = append(migrations, &result)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return migrations, nil
|
|
}
|
|
|
|
//
|
|
//
|
|
//
|
|
// Additional methods that are not part of the core Store interface, but are extended by the
|
|
// [controller.StoreController] type.
|
|
//
|
|
//
|
|
//
|
|
|
|
func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
|
|
q := s.querier.TableExists(s.tablename)
|
|
if q == "" {
|
|
return false, errors.ErrUnsupported
|
|
}
|
|
var exists bool
|
|
// Note, we do not pass the table name as an argument to the query, as the query should be
|
|
// pre-defined by the dialect.
|
|
if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
|
|
return false, fmt.Errorf("failed to check if table exists: %w", err)
|
|
}
|
|
return exists, nil
|
|
}
|