diff --git a/cmd/goose/driver_mysql.go b/cmd/goose/driver_mysql.go index 3fa0126..bff3243 100644 --- a/cmd/goose/driver_mysql.go +++ b/cmd/goose/driver_mysql.go @@ -3,11 +3,14 @@ package main import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" "log" + "regexp" "github.com/go-sql-driver/mysql" - - _ "github.com/go-sql-driver/mysql" _ "github.com/ziutek/mymysql/godrv" ) @@ -15,10 +18,16 @@ import ( // the parameter `parseTime` set to true. This allows internal goose logic // to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time // type. -func normalizeDBString(driver string, str string) string { +func normalizeDBString(driver string, str string, certfile string) string { if driver == "mysql" { + var isTLS = certfile != "" + if isTLS { + if err := registerTLSConfig(certfile); err != nil { + log.Fatalf("goose run: %v", err) + } + } var err error - str, err = normalizeMySQLDSN(str) + str, err = normalizeMySQLDSN(str, isTLS) if err != nil { log.Fatalf("failed to normalize MySQL connection string: %v", err) } @@ -26,11 +35,35 @@ func normalizeDBString(driver string, str string) string { return str } -func normalizeMySQLDSN(dsn string) (string, error) { +const tlsConfigKey = "custom" + +var tlsReg = regexp.MustCompile(`(\?|&)tls=[^&]*(?:&|$)`) + +func normalizeMySQLDSN(dsn string, tls bool) (string, error) { + // If we are sharing a DSN in a different environment, it may contain a TLS + // setting key with a value name that is not "custom," so clear it. + dsn = tlsReg.ReplaceAllString(dsn, `$1`) config, err := mysql.ParseDSN(dsn) if err != nil { return "", err } config.ParseTime = true + if tls { + config.TLSConfig = tlsConfigKey + } return config.FormatDSN(), nil } + +func registerTLSConfig(pemfile string) error { + rootCertPool := x509.NewCertPool() + pem, err := ioutil.ReadFile(pemfile) + if err != nil { + return err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return fmt.Errorf("failed to append PEM: %q", pemfile) + } + return mysql.RegisterTLSConfig(tlsConfigKey, &tls.Config{ + RootCAs: rootCertPool, + }) +} diff --git a/cmd/goose/driver_no_mysql.go b/cmd/goose/driver_no_mysql.go index a18ccff..5603ade 100644 --- a/cmd/goose/driver_no_mysql.go +++ b/cmd/goose/driver_no_mysql.go @@ -7,6 +7,6 @@ import ( _ "github.com/ziutek/mymysql/godrv" ) -func normalizeDBString(driver string, str string) string { +func normalizeDBString(driver string, str string, certfile string) string { return str } diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 8516b53..1b9e1fd 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -10,12 +10,13 @@ import ( ) var ( - flags = flag.NewFlagSet("goose", flag.ExitOnError) - dir = flags.String("dir", ".", "directory with migration files") - table = flags.String("table", "goose_db_version", "migrations table name") - verbose = flags.Bool("v", false, "enable verbose mode") - help = flags.Bool("h", false, "print help") - version = flags.Bool("version", false, "print version") + flags = flag.NewFlagSet("goose", flag.ExitOnError) + dir = flags.String("dir", ".", "directory with migration files") + table = flags.String("table", "goose_db_version", "migrations table name") + verbose = flags.Bool("v", false, "enable verbose mode") + help = flags.Bool("h", false, "print help") + version = flags.Bool("version", false, "print version") + certfile = flags.String("certfile", "", "file path to root CA's certificates in pem format (only support on mysql)") ) func main() { @@ -58,7 +59,7 @@ func main() { driver, dbstring, command := args[0], args[1], args[2] - db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring)) + db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring, *certfile)) if err != nil { log.Fatalf("-dbstring=%q: %v\n", dbstring, err) }