diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 1bd6fd2..bd9469b 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -55,24 +55,13 @@ func main() { driver, dbstring, command := args[0], args[1], args[2] - if err := goose.SetDialect(driver); err != nil { - log.Fatal(err) - } - - switch driver { - case "redshift": - driver = "postgres" - case "tidb": - driver = "mysql" - } - switch dbstring { case "": log.Fatalf("-dbstring=%q not supported\n", dbstring) default: } - db, err := createDBWithDriver(driver, dbstring) + db, err := goose.OpenDBWithDriver(driver, dbstring) if err != nil { log.Fatalf("-dbstring=%q: %v\n", dbstring, err) } diff --git a/cmd/goose/sql.go b/db.go similarity index 66% rename from cmd/goose/sql.go rename to db.go index f5c80a5..c4aeec5 100644 --- a/cmd/goose/sql.go +++ b/db.go @@ -1,4 +1,4 @@ -package main +package goose import ( "database/sql" @@ -20,7 +20,20 @@ func normalizeMySQLDSN(dsn string) (string, error) { return config.FormatDSN(), nil } -func createDBWithDriver(driver string, dbstring string) (*sql.DB, error) { +// OpenDBWithDriver creates a connection a database, and modifies goose +// internals to be compatible with the supplied driver by calling SetDialect. +func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) { + if err := SetDialect(driver); err != nil { + return nil, err + } + + switch driver { + case "redshift": + driver = "postgres" + case "tidb": + driver = "mysql" + } + switch driver { case "postgres", "sqlite3": return sql.Open(driver, dbstring) diff --git a/cmd/goose/sql_test.go b/db_test.go similarity index 99% rename from cmd/goose/sql_test.go rename to db_test.go index 74a8b92..dfc52b5 100644 --- a/cmd/goose/sql_test.go +++ b/db_test.go @@ -1,4 +1,4 @@ -package main +package goose import ( "testing" diff --git a/examples/go-migrations/main.go b/examples/go-migrations/main.go index be158d8..ce1f26b 100644 --- a/examples/go-migrations/main.go +++ b/examples/go-migrations/main.go @@ -1,7 +1,6 @@ package main import ( - "database/sql" "flag" "log" "os" @@ -55,28 +54,15 @@ func main() { driver, dbstring, command := args[0], args[1], args[2] - switch driver { - case "postgres", "mysql", "sqlite3", "redshift": - if err := goose.SetDialect(driver); err != nil { - log.Fatal(err) - } - default: - log.Fatalf("%q driver not supported\n", driver) - } - switch dbstring { case "": log.Fatalf("-dbstring=%q not supported\n", dbstring) default: } - if driver == "redshift" { - driver = "postgres" - } - - db, err := sql.Open(driver, dbstring) + db, err := goose.OpenDBWithDriver(driver, dbstring) if err != nil { - log.Fatalf("-dbstring=%q: %v\n", dbstring, err) + log.Fatalf("goose run: %v\n", err) } arguments := []string{}