diff --git a/conn.go b/conn.go index 125d9032..9509973b 100644 --- a/conn.go +++ b/conn.go @@ -4,10 +4,12 @@ import ( "context" "crypto/md5" "crypto/tls" + "crypto/x509" "encoding/binary" "encoding/hex" "fmt" "io" + "io/ioutil" "net" "net/url" "os" @@ -701,14 +703,23 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Dial = d.Dial } - err = configSSL(url.Query().Get("sslmode"), &cp) + tlsArgs := configTLSArgs{ + sslCert: url.Query().Get("sslcert"), + sslKey: url.Query().Get("sslkey"), + sslMode: url.Query().Get("sslmode"), + sslRootCert: url.Query().Get("sslrootcert"), + } + err = configTLS(tlsArgs, &cp) if err != nil { return cp, err } ignoreKeys := map[string]struct{}{ - "sslmode": {}, "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, } cp.RuntimeParams = make(map[string]string) @@ -744,7 +755,7 @@ func ParseDSN(s string) (ConnConfig, error) { m := dsnRegexp.FindAllStringSubmatch(s, -1) - var sslmode string + tlsArgs := configTLSArgs{} cp.RuntimeParams = make(map[string]string) @@ -765,7 +776,13 @@ func ParseDSN(s string) (ConnConfig, error) { case "dbname": cp.Database = b[2] case "sslmode": - sslmode = b[2] + tlsArgs.sslMode = b[2] + case "sslrootcert": + tlsArgs.sslRootCert = b[2] + case "sslcert": + tlsArgs.sslCert = b[2] + case "sslkey": + tlsArgs.sslKey = b[2] case "connect_timeout": timeout, err := strconv.ParseInt(b[2], 10, 64) if err != nil { @@ -779,7 +796,7 @@ func ParseDSN(s string) (ConnConfig, error) { } } - err := configSSL(sslmode, &cp) + err := configTLS(tlsArgs, &cp) if err != nil { return cp, err } @@ -859,7 +876,7 @@ func ParseEnvLibpq() (ConnConfig, error) { sslmode := os.Getenv("PGSSLMODE") - err := configSSL(sslmode, &cc) + err := configTLS(configTLSArgs{sslMode: sslmode}, &cc) if err != nil { return cc, err } @@ -874,14 +891,27 @@ func ParseEnvLibpq() (ConnConfig, error) { return cc, nil } -func configSSL(sslmode string, cc *ConnConfig) error { +type configTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config. +// Inputs are parsed out and provided by ParseDSN() or ParseURI(). +func configTLS(args configTLSArgs, cc *ConnConfig) error { // Match libpq default behavior - if sslmode == "" { - sslmode = "prefer" + if args.sslMode == "" { + args.sslMode = "prefer" } - switch sslmode { + switch args.sslMode { case "disable": + cc.UseFallbackTLS = false + cc.TLSConfig = nil + cc.FallbackTLSConfig = nil + return nil case "allow": cc.UseFallbackTLS = true cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} @@ -899,6 +929,39 @@ func configSSL(sslmode string, cc *ConnConfig) error { return errors.New("sslmode is invalid") } + if args.sslRootCert != "" { + caCertPool := x509.NewCertPool() + + caPath := args.sslRootCert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return errors.Wrap(err, "unable to add CA to cert pool") + } + + cc.TLSConfig.RootCAs = caCertPool + cc.TLSConfig.ClientCAs = caCertPool + } + + sslcert := args.sslCert + sslkey := args.sslKey + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return errors.Wrap(err, "unable to read cert") + } + + cc.TLSConfig.Certificates = []tls.Certificate{cert} + } + return nil } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 463c0841..096e1354 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -1,7 +1,14 @@ package pgx_test import ( - "github.com/jackc/pgx" + // "crypto/tls" + // "crypto/x509" + // "fmt" + // "go/build" + // "io/ioutil" + // "path" + + "github.com/jackc/pgx" ) var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} @@ -22,7 +29,51 @@ var cratedbConnConfig *pgx.ConnConfig = nil // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} -// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} +// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} +// +//// or to test client certs: +// +// var tlsConnConfig *pgx.ConnConfig +// +// func init() { +// homeDir := build.Default.GOPATH +// tlsConnConfig = &pgx.ConnConfig{ +// Host: "127.0.0.1", +// User: "pgx_md5", +// Password: "secret", +// Database: "pgx_test", +// TLSConfig: &tls.Config{ +// InsecureSkipVerify: true, +// }, +// } +// caCertPool := x509.NewCertPool() +// +// caPath := path.Join(homeDir, "/src/github.com/jackc/pgx/rootCA.pem") +// caCert, err := ioutil.ReadFile(caPath) +// if err != nil { +// panic(fmt.Sprintf("unable to read CA file: %v", err)) +// } +// +// if !caCertPool.AppendCertsFromPEM(caCert) { +// panic("unable to add CA to cert pool") +// } +// +// tlsConnConfig.TLSConfig.RootCAs = caCertPool +// tlsConnConfig.TLSConfig.ClientCAs = caCertPool +// +// sslCert := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.crt") +// sslKey := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.key") +// if (sslCert != "" && sslKey == "") || (sslCert == "" && sslKey != "") { +// panic(`both "sslcert" and "sslkey" are required`) +// } +// +// cert, err := tls.LoadX509KeyPair(sslCert, sslKey) +// if err != nil { +// panic(fmt.Sprintf("unable to read cert: %v", err)) +// } +// +// tlsConnConfig.TLSConfig.Certificates = []tls.Certificate{cert} +// } diff --git a/conn_test.go b/conn_test.go index 6144521d..6f1d41ea 100644 --- a/conn_test.go +++ b/conn_test.go @@ -228,7 +228,8 @@ func TestConnectWithTLSFallback(t *testing.T) { } connConfig.UseFallbackTLS = true - connConfig.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + connConfig.FallbackTLSConfig = tlsConnConfig.TLSConfig + connConfig.FallbackTLSConfig.InsecureSkipVerify = true conn, err = pgx.Connect(connConfig) if err != nil {