diff --git a/create.go b/create.go index 91843e2..48bc7f9 100644 --- a/create.go +++ b/create.go @@ -10,7 +10,7 @@ import ( ) // Create writes a new blank migration file. -func Create(db *sql.DB, dir, name, migrationType string) error { +func CreateWithTemplate(db *sql.DB, dir string, migrationTemplate *template.Template, name, migrationType string) error { migrations, err := CollectMigrations(dir, minVersion, maxVersion) if err != nil { return err @@ -26,11 +26,16 @@ func Create(db *sql.DB, dir, name, migrationType string) error { filename := fmt.Sprintf("%v_%v.%v", version, name, migrationType) fpath := filepath.Join(dir, filename) + tmpl := sqlMigrationTemplate if migrationType == "go" { tmpl = goSQLMigrationTemplate } + if migrationTemplate != nil { + tmpl = migrationTemplate + } + path, err := writeTemplateToFile(fpath, tmpl, version) if err != nil { return err @@ -40,6 +45,11 @@ func Create(db *sql.DB, dir, name, migrationType string) error { return nil } +// Create writes a new blank migration file. +func Create(db *sql.DB, dir, name, migrationType string) error { + return CreateWithTemplate(db, dir, nil, name, migrationType) +} + func writeTemplateToFile(path string, t *template.Template, version string) (string, error) { if _, err := os.Stat(path); !os.IsNotExist(err) { return "", fmt.Errorf("failed to create file: %v already exists", path)