automatically set TLS config value to DSN when using TLS on mysql

pull/180/head
Songmu 2019-06-20 17:16:52 +09:00
parent cf288525b2
commit 521de29112
3 changed files with 12 additions and 8 deletions

View File

@ -17,10 +17,10 @@ import (
// the parameter `parseTime` set to true. This allows internal goose logic // the parameter `parseTime` set to true. This allows internal goose logic
// to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time // to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time
// type. // type.
func normalizeDBString(driver string, str string) string { func normalizeDBString(driver string, str string, tls bool) string {
if driver == "mysql" { if driver == "mysql" {
var err error var err error
str, err = normalizeMySQLDSN(str) str, err = normalizeMySQLDSN(str, tls)
if err != nil { if err != nil {
log.Fatalf("failed to normalize MySQL connection string: %v", err) log.Fatalf("failed to normalize MySQL connection string: %v", err)
} }
@ -28,17 +28,20 @@ func normalizeDBString(driver string, str string) string {
return str return str
} }
func normalizeMySQLDSN(dsn string) (string, error) { const tlsConfigKey = "custom"
func normalizeMySQLDSN(dsn string, tls bool) (string, error) {
config, err := mysql.ParseDSN(dsn) config, err := mysql.ParseDSN(dsn)
if err != nil { if err != nil {
return "", err return "", err
} }
config.ParseTime = true config.ParseTime = true
if tls {
config.TLSConfig = tlsConfigKey
}
return config.FormatDSN(), nil return config.FormatDSN(), nil
} }
const tlsConfigKey = "custom"
func registerTLSConfig(pemfile string) error { func registerTLSConfig(pemfile string) error {
rootCertPool := x509.NewCertPool() rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(pemfile) pem, err := ioutil.ReadFile(pemfile)

View File

@ -7,7 +7,7 @@ import (
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
) )
func normalizeDBString(driver string, str string) string { func normalizeDBString(driver string, str string, tls bool) string {
return str return str
} }

View File

@ -36,7 +36,8 @@ func main() {
return return
} }
if *sslCA != "" { tls := *sslCA != ""
if tls {
if err := registerTLSConfig(*sslCA); err != nil { if err := registerTLSConfig(*sslCA); err != nil {
log.Fatalf("goose run: %v", err) log.Fatalf("goose run: %v", err)
} }
@ -62,7 +63,7 @@ func main() {
driver, dbstring, command := args[0], args[1], args[2] driver, dbstring, command := args[0], args[1], args[2]
db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring)) db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring, tls))
if err != nil { if err != nil {
log.Fatalf("-dbstring=%q: %v\n", dbstring, err) log.Fatalf("-dbstring=%q: %v\n", dbstring, err)
} }