From 8078930406ebb0b7f182df3c9c0dfa23d133812e Mon Sep 17 00:00:00 2001 From: Sean Chittenden Date: Thu, 1 Feb 2018 23:51:50 -0800 Subject: [PATCH] Add TLS arg parsing to ParseDSN(). Factor out the TLS cert handling and add it to `configTLS()` via a `struct` argument. --- conn.go | 108 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 65 insertions(+), 43 deletions(-) diff --git a/conn.go b/conn.go index b9224200..8064a8be 100644 --- a/conn.go +++ b/conn.go @@ -703,45 +703,17 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Dial = d.Dial } - err = configTLS(url.Query().Get("sslmode"), &cp) + tlsArgs := configTLSArgs{ + sslCert: url.Query().Get("sslcert"), + sslKey: url.Query().Get("sslkey"), + sslMode: url.Query().Get("sslmode"), + sslRootCert: url.Query().Get("sslrootcert"), + } + err = configTLS(tlsArgs, &cp) if err != nil { return cp, err } - // Extract optional TLS parameters and reconstruct a coherent tls.Config based - // on the DSN input. Reuse the same keywords found in github.com/lib/pq. - if cp.TLSConfig != nil { - { - caCertPool := x509.NewCertPool() - - caPath := url.Query().Get("sslrootcert") - caCert, err := ioutil.ReadFile(caPath) - if err != nil { - return cp, errors.Wrapf(err, "unable to read CA file %q", caPath) - } - - if !caCertPool.AppendCertsFromPEM(caCert) { - return cp, errors.Wrap(err, "unable to add CA to cert pool") - } - - cp.TLSConfig.RootCAs = caCertPool - cp.TLSConfig.ClientCAs = caCertPool - } - - sslcert := url.Query().Get("sslcert") - sslkey := url.Query().Get("sslkey") - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - return cp, errors.Wrap(err, "unable to read cert") - } - - cp.TLSConfig.Certificates = []tls.Certificate{cert} - } - ignoreKeys := map[string]struct{}{ "connect_timeout": {}, "sslcert": {}, @@ -783,7 +755,7 @@ func ParseDSN(s string) (ConnConfig, error) { m := dsnRegexp.FindAllStringSubmatch(s, -1) - var sslmode string + tlsArgs := configTLSArgs{} cp.RuntimeParams = make(map[string]string) @@ -804,7 +776,13 @@ func ParseDSN(s string) (ConnConfig, error) { case "dbname": cp.Database = b[2] case "sslmode": - sslmode = b[2] + tlsArgs.sslMode = b[2] + case "sslrootcert": + tlsArgs.sslRootCert = b[2] + case "sslcert": + tlsArgs.sslCert = b[2] + case "sslkey": + tlsArgs.sslKey = b[2] case "connect_timeout": timeout, err := strconv.ParseInt(b[2], 10, 64) if err != nil { @@ -818,7 +796,7 @@ func ParseDSN(s string) (ConnConfig, error) { } } - err := configTLS(sslmode, &cp) + err := configTLS(tlsArgs, &cp) if err != nil { return cp, err } @@ -898,7 +876,7 @@ func ParseEnvLibpq() (ConnConfig, error) { sslmode := os.Getenv("PGSSLMODE") - err := configTLS(sslmode, &cc) + err := configTLS(configTLSArgs{sslMode: sslmode}, &cc) if err != nil { return cc, err } @@ -913,14 +891,27 @@ func ParseEnvLibpq() (ConnConfig, error) { return cc, nil } -func configTLS(sslmode string, cc *ConnConfig) error { +type configTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config. +// Inputs are parsed out and provided by ParseDSN() or ParseURI(). +func configTLS(args configTLSArgs, cc *ConnConfig) error { // Match libpq default behavior - if sslmode == "" { - sslmode = "prefer" + if args.sslMode == "" { + args.sslMode = "prefer" } - switch sslmode { + switch args.sslMode { case "disable": + cc.UseFallbackTLS = false + cc.TLSConfig = nil + cc.FallbackTLSConfig = nil + return nil case "allow": cc.UseFallbackTLS = true cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} @@ -938,6 +929,37 @@ func configTLS(sslmode string, cc *ConnConfig) error { return errors.New("sslmode is invalid") } + { + caCertPool := x509.NewCertPool() + + caPath := args.sslRootCert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return errors.Wrap(err, "unable to add CA to cert pool") + } + + cc.TLSConfig.RootCAs = caCertPool + cc.TLSConfig.ClientCAs = caCertPool + } + + sslcert := args.sslCert + sslkey := args.sslKey + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return errors.Wrap(err, "unable to read cert") + } + + cc.TLSConfig.Certificates = []tls.Certificate{cert} + return nil }