Merge pull request #180 from Songmu/tls

add `--certfile` option to support TLS connection on mysql
pull/211/head
Vojtech Vitek 2020-04-16 19:23:55 +02:00 committed by GitHub
commit 9ede98d097
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 13 deletions

View File

@ -3,11 +3,14 @@
package main package main
import ( import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"log" "log"
"regexp"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
) )
@ -15,10 +18,16 @@ import (
// the parameter `parseTime` set to true. This allows internal goose logic // the parameter `parseTime` set to true. This allows internal goose logic
// to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time // to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time
// type. // type.
func normalizeDBString(driver string, str string) string { func normalizeDBString(driver string, str string, certfile string) string {
if driver == "mysql" { if driver == "mysql" {
var isTLS = certfile != ""
if isTLS {
if err := registerTLSConfig(certfile); err != nil {
log.Fatalf("goose run: %v", err)
}
}
var err error var err error
str, err = normalizeMySQLDSN(str) str, err = normalizeMySQLDSN(str, isTLS)
if err != nil { if err != nil {
log.Fatalf("failed to normalize MySQL connection string: %v", err) log.Fatalf("failed to normalize MySQL connection string: %v", err)
} }
@ -26,11 +35,35 @@ func normalizeDBString(driver string, str string) string {
return str return str
} }
func normalizeMySQLDSN(dsn string) (string, error) { const tlsConfigKey = "custom"
var tlsReg = regexp.MustCompile(`(\?|&)tls=[^&]*(?:&|$)`)
func normalizeMySQLDSN(dsn string, tls bool) (string, error) {
// If we are sharing a DSN in a different environment, it may contain a TLS
// setting key with a value name that is not "custom," so clear it.
dsn = tlsReg.ReplaceAllString(dsn, `$1`)
config, err := mysql.ParseDSN(dsn) config, err := mysql.ParseDSN(dsn)
if err != nil { if err != nil {
return "", err return "", err
} }
config.ParseTime = true config.ParseTime = true
if tls {
config.TLSConfig = tlsConfigKey
}
return config.FormatDSN(), nil return config.FormatDSN(), nil
} }
func registerTLSConfig(pemfile string) error {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(pemfile)
if err != nil {
return err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append PEM: %q", pemfile)
}
return mysql.RegisterTLSConfig(tlsConfigKey, &tls.Config{
RootCAs: rootCertPool,
})
}

View File

@ -7,6 +7,6 @@ import (
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
) )
func normalizeDBString(driver string, str string) string { func normalizeDBString(driver string, str string, certfile string) string {
return str return str
} }

View File

@ -10,12 +10,13 @@ import (
) )
var ( var (
flags = flag.NewFlagSet("goose", flag.ExitOnError) flags = flag.NewFlagSet("goose", flag.ExitOnError)
dir = flags.String("dir", ".", "directory with migration files") dir = flags.String("dir", ".", "directory with migration files")
table = flags.String("table", "goose_db_version", "migrations table name") table = flags.String("table", "goose_db_version", "migrations table name")
verbose = flags.Bool("v", false, "enable verbose mode") verbose = flags.Bool("v", false, "enable verbose mode")
help = flags.Bool("h", false, "print help") help = flags.Bool("h", false, "print help")
version = flags.Bool("version", false, "print version") version = flags.Bool("version", false, "print version")
certfile = flags.String("certfile", "", "file path to root CA's certificates in pem format (only support on mysql)")
) )
func main() { func main() {
@ -58,7 +59,7 @@ func main() {
driver, dbstring, command := args[0], args[1], args[2] driver, dbstring, command := args[0], args[1], args[2]
db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring)) db, err := goose.OpenDBWithDriver(driver, normalizeDBString(driver, dbstring, *certfile))
if err != nil { if err != nil {
log.Fatalf("-dbstring=%q: %v\n", dbstring, err) log.Fatalf("-dbstring=%q: %v\n", dbstring, err)
} }