package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/pressly/goose/v3/internal/dialect/dialectquery"
)

// Dialect is the type of database dialect.
type Dialect string

const (
	DialectClickHouse Dialect = "clickhouse"
	DialectMSSQL      Dialect = "mssql"
	DialectMySQL      Dialect = "mysql"
	DialectPostgres   Dialect = "postgres"
	DialectRedshift   Dialect = "redshift"
	DialectSQLite3    Dialect = "sqlite3"
	DialectTiDB       Dialect = "tidb"
	DialectTurso      Dialect = "turso"
	DialectVertica    Dialect = "vertica"
	DialectYdB        Dialect = "ydb"
	DialectStarrocks  Dialect = "starrocks"
)

// NewStore returns a new [Store] implementation for the given dialect.
//
// tablename is the name of the table the store reads and writes migration versions in. An
// error is returned if tablename or dialect is empty, or if dialect is not one of the Dialect
// constants defined in this file.
func NewStore(dialect Dialect, tablename string) (Store, error) {
	if tablename == "" {
		return nil, errors.New("table name must not be empty")
	}
	if dialect == "" {
		return nil, errors.New("dialect must not be empty")
	}
	// The map is built on every call, so each store gets its own freshly constructed querier
	// value rather than sharing one across stores.
	lookup := map[Dialect]dialectquery.Querier{
		DialectClickHouse: &dialectquery.Clickhouse{},
		DialectMSSQL:      &dialectquery.Sqlserver{},
		DialectMySQL:      &dialectquery.Mysql{},
		DialectPostgres:   &dialectquery.Postgres{},
		DialectRedshift:   &dialectquery.Redshift{},
		DialectSQLite3:    &dialectquery.Sqlite3{},
		DialectTiDB:       &dialectquery.Tidb{},
		DialectVertica:    &dialectquery.Vertica{},
		DialectYdB:        &dialectquery.Ydb{},
		DialectTurso:      &dialectquery.Turso{},
		DialectStarrocks:  &dialectquery.Starrocks{},
	}
	querier, ok := lookup[dialect]
	if !ok {
		return nil, fmt.Errorf("unknown dialect: %q", dialect)
	}
	return &store{
		tablename: tablename,
		querier:   dialectquery.NewQueryController(querier),
	}, nil
}

// store is the [Store] implementation returned by NewStore. It pairs the version table name
// with a dialect-specific query controller that supplies the SQL text for each operation;
// every method builds its query from s.tablename and executes it on the DBTxConn it is given.
type store struct {
	tablename string
	querier   *dialectquery.QueryController
}

// Compile-time check that *store satisfies the Store interface.
var _ Store = (*store)(nil)

// Tablename returns the name of the version table this store operates on.
func (s *store) Tablename() string {
	return s.tablename
}

// CreateVersionTable executes the dialect's create-table query for the version table.
func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
	q := s.querier.CreateTable(s.tablename)
	if _, err := db.ExecContext(ctx, q); err != nil {
		return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
	}
	return nil
}

// Insert executes the dialect's insert-version query with req.Version and the literal true as
// arguments. NOTE(review): the second argument presumably populates the "is applied" column
// read back by GetMigration/ListMigrations — confirm against the dialect query.
func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
	q := s.querier.InsertVersion(s.tablename)
	if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
		return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
	}
	return nil
}

// Delete executes the dialect's delete-version query for the given version.
func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
	q := s.querier.DeleteVersion(s.tablename)
	if _, err := db.ExecContext(ctx, q, version); err != nil {
		return fmt.Errorf("failed to delete version %d: %w", version, err)
	}
	return nil
}

// GetMigration returns the timestamp and applied flag recorded for version.
//
// If the query returns no row, the returned error wraps ErrVersionNotFound (test for it with
// errors.Is). Any other query or scan failure is wrapped with the version number.
func (s *store) GetMigration(
	ctx context.Context,
	db DBTxConn,
	version int64,
) (*GetMigrationResult, error) {
	q := s.querier.GetMigrationByVersion(s.tablename)
	var result GetMigrationResult
	if err := db.QueryRowContext(ctx, q, version).Scan(
		&result.Timestamp,
		&result.IsApplied,
	); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
		}
		return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
	}
	return &result, nil
}

// GetLatestVersion returns the single value produced by the dialect's latest-version query.
//
// On any error the returned version is -1. The value is scanned into a sql.NullInt64, and a
// NULL result is reported as an error wrapping ErrVersionNotFound. NOTE(review): a NULL
// presumably comes from an aggregate over an empty table — confirm against the dialect query.
func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
	q := s.querier.GetLatestVersion(s.tablename)
	var version sql.NullInt64
	if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
		return -1, fmt.Errorf("failed to get latest version: %w", err)
	}
	if !version.Valid {
		return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
	}
	return version.Int64, nil
}

// ListMigrations returns one result (version and applied flag) per row of the dialect's
// list-migrations query, in the order the query returns them. A table with no rows yields a
// nil slice and a nil error.
func (s *store) ListMigrations(
	ctx context.Context,
	db DBTxConn,
) ([]*ListMigrationsResult, error) {
	q := s.querier.ListMigrations(s.tablename)
	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, fmt.Errorf("failed to list migrations: %w", err)
	}
	defer rows.Close()

	var migrations []*ListMigrationsResult
	for rows.Next() {
		var result ListMigrationsResult
		if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
			return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
		}
		migrations = append(migrations, &result)
	}
	// rows.Next returns false both when rows are exhausted and when iteration fails; rows.Err
	// tells the two apart. Wrap it like every other error path here so callers get context.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate list migrations results: %w", err)
	}
	return migrations, nil
}

// The methods below are not part of the core Store interface. They are additional
// capabilities that the [controller.StoreController] type extends the Store with.

// TableExists reports whether the version table exists, using the dialect's table-exists
// query. It returns errors.ErrUnsupported when the dialect provides no such query (an empty
// query string).
func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
	q := s.querier.TableExists(s.tablename)
	if q == "" {
		return false, errors.ErrUnsupported
	}
	var exists bool
	// Note, we do not pass the table name as an argument to the query, as the query should be
	// pre-defined by the dialect.
	if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
		return false, fmt.Errorf("failed to check if table exists: %w", err)
	}
	return exists, nil
}