From d7f24b91f4d9ce819213e47e333d585598f17d37 Mon Sep 17 00:00:00 2001 From: Sean Chittenden Date: Thu, 1 Feb 2018 22:58:14 -0800 Subject: [PATCH] Make ParseURI() compatible with lib/pq's TLS keywords. Add support for: - `sslrootcert` - `sslcert` - `sslkey` All three arguments, like thir `gitub.com/lib/pq` counterparts, are filesystem paths. --- conn.go | 41 ++++++++++++++++++++++++++++++++++++++++- conn_test.go | 3 ++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index efeba5ff..b9224200 100644 --- a/conn.go +++ b/conn.go @@ -4,10 +4,12 @@ import ( "context" "crypto/md5" "crypto/tls" + "crypto/x509" "encoding/binary" "encoding/hex" "fmt" "io" + "io/ioutil" "net" "net/url" "os" @@ -706,9 +708,46 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, err } + // Extract optional TLS parameters and reconstruct a coherent tls.Config based + // on the DSN input. Reuse the same keywords found in github.com/lib/pq. + if cp.TLSConfig != nil { + { + caCertPool := x509.NewCertPool() + + caPath := url.Query().Get("sslrootcert") + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return cp, errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return cp, errors.Wrap(err, "unable to add CA to cert pool") + } + + cp.TLSConfig.RootCAs = caCertPool + cp.TLSConfig.ClientCAs = caCertPool + } + + sslcert := url.Query().Get("sslcert") + sslkey := url.Query().Get("sslkey") + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return cp, errors.Wrap(err, "unable to read cert") + } + + cp.TLSConfig.Certificates = []tls.Certificate{cert} + } + ignoreKeys := map[string]struct{}{ - "sslmode": {}, "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, } cp.RuntimeParams = make(map[string]string) diff --git a/conn_test.go b/conn_test.go index 6144521d..6f1d41ea 100644 --- a/conn_test.go +++ b/conn_test.go @@ -228,7 +228,8 @@ func TestConnectWithTLSFallback(t *testing.T) { } connConfig.UseFallbackTLS = true - connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + connConfig.FallbackTLSConfig = tlsConnConfig.TLSConfig + connConfig.FallbackTLSConfig.InsecureSkipVerify = true conn, err = pgx.Connect(connConfig) if err != nil {