Add TLS arg parsing to ParseDSN().

Factor out the TLS cert handling and add it to `configTLS()` via a
`struct` argument.
pull/385/head
Sean Chittenden 2018-02-01 23:51:50 -08:00
parent d7f24b91f4
commit 8078930406
No known key found for this signature in database
GPG Key ID: 4EBC9DC16C2E5E16
1 changed files with 65 additions and 43 deletions

108
conn.go
View File

@ -703,45 +703,17 @@ func ParseURI(uri string) (ConnConfig, error) {
cp.Dial = d.Dial
}
err = configTLS(url.Query().Get("sslmode"), &cp)
tlsArgs := configTLSArgs{
sslCert: url.Query().Get("sslcert"),
sslKey: url.Query().Get("sslkey"),
sslMode: url.Query().Get("sslmode"),
sslRootCert: url.Query().Get("sslrootcert"),
}
err = configTLS(tlsArgs, &cp)
if err != nil {
return cp, err
}
// Extract optional TLS parameters and reconstruct a coherent tls.Config based
// on the DSN input. Reuse the same keywords found in github.com/lib/pq.
if cp.TLSConfig != nil {
{
caCertPool := x509.NewCertPool()
caPath := url.Query().Get("sslrootcert")
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return cp, errors.Wrapf(err, "unable to read CA file %q", caPath)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return cp, errors.Wrap(err, "unable to add CA to cert pool")
}
cp.TLSConfig.RootCAs = caCertPool
cp.TLSConfig.ClientCAs = caCertPool
}
sslcert := url.Query().Get("sslcert")
sslkey := url.Query().Get("sslkey")
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
}
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return cp, errors.Wrap(err, "unable to read cert")
}
cp.TLSConfig.Certificates = []tls.Certificate{cert}
}
ignoreKeys := map[string]struct{}{
"connect_timeout": {},
"sslcert": {},
@ -783,7 +755,7 @@ func ParseDSN(s string) (ConnConfig, error) {
m := dsnRegexp.FindAllStringSubmatch(s, -1)
var sslmode string
tlsArgs := configTLSArgs{}
cp.RuntimeParams = make(map[string]string)
@ -804,7 +776,13 @@ func ParseDSN(s string) (ConnConfig, error) {
case "dbname":
cp.Database = b[2]
case "sslmode":
sslmode = b[2]
tlsArgs.sslMode = b[2]
case "sslrootcert":
tlsArgs.sslRootCert = b[2]
case "sslcert":
tlsArgs.sslCert = b[2]
case "sslkey":
tlsArgs.sslKey = b[2]
case "connect_timeout":
timeout, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
@ -818,7 +796,7 @@ func ParseDSN(s string) (ConnConfig, error) {
}
}
err := configTLS(sslmode, &cp)
err := configTLS(tlsArgs, &cp)
if err != nil {
return cp, err
}
@ -898,7 +876,7 @@ func ParseEnvLibpq() (ConnConfig, error) {
sslmode := os.Getenv("PGSSLMODE")
err := configTLS(sslmode, &cc)
err := configTLS(configTLSArgs{sslMode: sslmode}, &cc)
if err != nil {
return cc, err
}
@ -913,14 +891,27 @@ func ParseEnvLibpq() (ConnConfig, error) {
return cc, nil
}
func configTLS(sslmode string, cc *ConnConfig) error {
type configTLSArgs struct {
sslMode string
sslRootCert string
sslCert string
sslKey string
}
// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
// Inputs are parsed out and provided by ParseDSN() or ParseURI().
func configTLS(args configTLSArgs, cc *ConnConfig) error {
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
if args.sslMode == "" {
args.sslMode = "prefer"
}
switch sslmode {
switch args.sslMode {
case "disable":
cc.UseFallbackTLS = false
cc.TLSConfig = nil
cc.FallbackTLSConfig = nil
return nil
case "allow":
cc.UseFallbackTLS = true
cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
@ -938,6 +929,37 @@ func configTLS(sslmode string, cc *ConnConfig) error {
return errors.New("sslmode is invalid")
}
{
caCertPool := x509.NewCertPool()
caPath := args.sslRootCert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return errors.Wrapf(err, "unable to read CA file %q", caPath)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return errors.Wrap(err, "unable to add CA to cert pool")
}
cc.TLSConfig.RootCAs = caCertPool
cc.TLSConfig.ClientCAs = caCertPool
}
sslcert := args.sslCert
sslkey := args.sslKey
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return fmt.Errorf(`both "sslcert" and "sslkey" are required`)
}
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return errors.Wrap(err, "unable to read cert")
}
cc.TLSConfig.Certificates = []tls.Certificate{cert}
return nil
}