Make ParseURI() compatible with lib/pq's TLS keywords.

Add support for:

- `sslrootcert`
- `sslcert`
- `sslkey`

All three arguments, like thir `gitub.com/lib/pq` counterparts,
are filesystem paths.
This commit is contained in:
Sean Chittenden 2018-02-01 22:58:14 -08:00
parent 4506a3e359
commit d7f24b91f4
No known key found for this signature in database
GPG Key ID: 4EBC9DC16C2E5E16
2 changed files with 42 additions and 2 deletions

41
conn.go
View File

@ -4,10 +4,12 @@ import (
"context"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
@ -706,9 +708,46 @@ func ParseURI(uri string) (ConnConfig, error) {
return cp, err
}
// Extract optional TLS parameters and reconstruct a coherent tls.Config based
// on the DSN input. Reuse the same keywords found in github.com/lib/pq.
if cp.TLSConfig != nil {
{
caCertPool := x509.NewCertPool()
caPath := url.Query().Get("sslrootcert")
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return cp, errors.Wrapf(err, "unable to read CA file %q", caPath)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return cp, errors.Wrap(err, "unable to add CA to cert pool")
}
cp.TLSConfig.RootCAs = caCertPool
cp.TLSConfig.ClientCAs = caCertPool
}
sslcert := url.Query().Get("sslcert")
sslkey := url.Query().Get("sslkey")
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
}
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return cp, errors.Wrap(err, "unable to read cert")
}
cp.TLSConfig.Certificates = []tls.Certificate{cert}
}
ignoreKeys := map[string]struct{}{
"sslmode": {},
"connect_timeout": {},
"sslcert": {},
"sslkey": {},
"sslmode": {},
"sslrootcert": {},
}
cp.RuntimeParams = make(map[string]string)

View File

@ -228,7 +228,8 @@ func TestConnectWithTLSFallback(t *testing.T) {
}
connConfig.UseFallbackTLS = true
connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
connConfig.FallbackTLSConfig = tlsConnConfig.TLSConfig
connConfig.FallbackTLSConfig.InsecureSkipVerify = true
conn, err = pgx.Connect(connConfig)
if err != nil {