goose/internal/provider/collect.go

233 lines
7.1 KiB
Go

package provider
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
)
func NewSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Fullpath: fullpath,
Version: version,
}
}
// fileSources represents a collection of migration files on the filesystem.
type fileSources struct {
sqlSources []Source
goSources []Source
}
// TODO(mf): remove?
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
switch t {
case TypeGo:
for _, source := range s.goSources {
if source.Version == version {
return &source
}
}
case TypeSQL:
for _, source := range s.sqlSources {
if source.Version == version {
return &source
}
}
}
return nil
}
// collectFileSources scans the file system for migration files that have a numeric prefix (greater
// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil,
// in which case an empty fileSources is returned.
//
// If strict is true, then any error parsing the numeric component of the filename will result in an
// error. The file is skipped otherwise.
//
// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects
// migration sources from the filesystem.
func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
}
sources := new(fileSources)
versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath)
for _, pattern := range []string{
"*.sql",
"*.go",
} {
files, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
}
for _, fullpath := range files {
base := filepath.Base(fullpath)
// Skip explicit excludes or Go test files.
if excludes[base] || strings.HasSuffix(base, "_test.go") {
continue
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
continue
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version))
case ".go":
sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version))
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
}
}
return sources, nil
}
func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) {
var migrations []*migration
migrationLookup := make(map[int64]*migration)
// Add all SQL migrations to the list of migrations.
for _, source := range sources.sqlSources {
m := &migration{
Source: source,
SQL: nil, // SQL migrations are parsed lazily.
}
migrations = append(migrations, m)
migrationLookup[source.Version] = m
}
// If there are no Go files in the filesystem and no registered Go migrations, return early.
if len(sources.goSources) == 0 && len(registerd) == 0 {
return migrations, nil
}
// Return an error if the given sources contain a versioned Go migration that has not been
// registered. This is a sanity check to ensure users didn't accidentally create a valid looking
// Go migration file on disk and forget to register it.
//
// This is almost always a user error.
var unregistered []string
for _, s := range sources.goSources {
if _, ok := registerd[s.Version]; !ok {
unregistered = append(unregistered, s.Fullpath)
}
}
if len(unregistered) > 0 {
return nil, unregisteredError(unregistered)
}
// Add all registered Go migrations to the list of migrations, checking for duplicate versions.
//
// Important, users can register Go migrations manually via goose.Add_ functions. These
// migrations may not have a corresponding file on disk. Which is fine! We include them
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
// the SQL migration files.
for version, r := range registerd {
fullpath := r.fullpath
if fullpath == "" {
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath
}
}
// Ensure there are no duplicate versions.
if existing, ok := migrationLookup[version]; ok {
fullpath := r.fullpath
if fullpath == "" {
fullpath = "manually registered (no source)"
}
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing.Source.Fullpath,
fullpath,
)
}
m := &migration{
// Note, the fullpath may be empty if the migration was registered manually.
Source: NewSource(TypeGo, fullpath, version),
Go: r,
}
migrations = append(migrations, m)
migrationLookup[version] = m
}
// Sort migrations by version in ascending order.
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Source.Version < migrations[j].Source.Version
})
return migrations, nil
}
func unregisteredError(unregistered []string) error {
const (
hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations"
)
f := "file"
if len(unregistered) > 1 {
f += "s"
}
var b strings.Builder
b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f))
for _, name := range unregistered {
b.WriteString("\t" + name + "\n")
}
hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL)
b.WriteString(hint)
b.WriteString("\n")
return errors.New(b.String())
}
type noopFS struct{}
var _ fs.FS = noopFS{}
func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}
// NumericComponent parses the version from the migration file name.
//
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
// migration, either .sql or .go.
func NumericComponent(filename string) (int64, error) {
base := filepath.Base(filename)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("migration file does not have .sql or .go file extension")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no filename separator '_' found")
}
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
}
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
}
return n, nil
}