goose/database/dialects.go
2025-08-04 10:37:35 -04:00

211 lines
6.1 KiB
Go

package database
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3/database/dialect"
"github.com/pressly/goose/v3/internal/dialects"
)
// Dialect is the type of database dialect.
type Dialect string
const (
DialectCustom Dialect = ""
DialectClickHouse Dialect = "clickhouse"
DialectAuroraDSQL Dialect = "dsql"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectStarrocks Dialect = "starrocks"
DialectTiDB Dialect = "tidb"
DialectTurso Dialect = "turso"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"
)
// NewStore returns a new [Store] implementation for the given dialect.
func NewStore(d Dialect, tableName string) (Store, error) {
if d == DialectCustom {
return nil, errors.New("custom dialect is not supported")
}
lookup := map[Dialect]dialect.Querier{
DialectClickHouse: dialects.NewClickhouse(),
DialectAuroraDSQL: dialects.NewAuroraDSQL(),
DialectMSSQL: dialects.NewSqlserver(),
DialectMySQL: dialects.NewMysql(),
DialectPostgres: dialects.NewPostgres(),
DialectRedshift: dialects.NewRedshift(),
DialectSQLite3: dialects.NewSqlite3(),
DialectStarrocks: dialects.NewStarrocks(),
DialectTiDB: dialects.NewTidb(),
DialectTurso: dialects.NewTurso(),
DialectVertica: dialects.NewVertica(),
DialectYdB: dialects.NewYDB(),
}
querier, ok := lookup[d]
if !ok {
return nil, fmt.Errorf("unknown dialect: %q", d)
}
return NewStoreFromQuerier(tableName, querier)
}
// NewStoreFromQuerier returns a new [Store] implementation for the given querier.
//
// Most of the time you should use [NewStore] instead of this function, as it will automatically
// create a dialect-specific querier for you. This function is useful if you want to use a custom
// querier that is not part of the standard dialects.
func NewStoreFromQuerier(tableName string, querier dialect.Querier) (Store, error) {
if tableName == "" {
return nil, errors.New("table name must not be empty")
}
if querier == nil {
return nil, errors.New("querier must not be nil")
}
return &store{
tableName: tableName,
querier: newQueryController(querier),
}, nil
}
type store struct {
tableName string
querier *queryController
}
var _ Store = (*store)(nil)
func (s *store) Tablename() string {
return s.tableName
}
func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tableName)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tableName, err)
}
return nil
}
func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
q := s.querier.InsertVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
}
return nil
}
func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
q := s.querier.DeleteVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
return nil
}
func (s *store) GetMigration(
ctx context.Context,
db DBTxConn,
version int64,
) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tableName)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
&result.IsApplied,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
}
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return &result, nil
}
func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
q := s.querier.GetLatestVersion(s.tableName)
var version sql.NullInt64
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
return -1, fmt.Errorf("failed to get latest version: %w", err)
}
if !version.Valid {
return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
}
return version.Int64, nil
}
func (s *store) ListMigrations(
ctx context.Context,
db DBTxConn,
) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tableName)
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
}
defer rows.Close()
var migrations []*ListMigrationsResult
for rows.Next() {
var result ListMigrationsResult
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
}
migrations = append(migrations, &result)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrations, nil
}
//
//
//
// Additional methods that are not part of the core Store interface, but are extended by the
// [controller.StoreController] type.
//
//
//
func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
q := s.querier.TableExists(s.tableName)
if q == "" {
return false, errors.ErrUnsupported
}
var exists bool
// Note, we do not pass the table name as an argument to the query, as the query should be
// pre-defined by the dialect.
if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
return false, fmt.Errorf("failed to check if table exists: %w", err)
}
return exists, nil
}
var _ dialect.Querier = (*queryController)(nil)
type queryController struct{ dialect.Querier }
// newQueryController returns a new QueryController that wraps the given Querier.
func newQueryController(querier dialect.Querier) *queryController {
return &queryController{Querier: querier}
}
// Optional methods
// TableExists returns the SQL query string to check if the version table exists. If the Querier
// does not implement this method, it will return an empty string.
//
// Returns a boolean value.
func (c *queryController) TableExists(tableName string) string {
if t, ok := c.Querier.(interface{ TableExists(string) string }); ok {
return t.TableExists(tableName)
}
return ""
}