Add support for SslPassword

pull/1281/head
Eric McCormack 2022-06-21 11:19:46 -04:00 committed by Jack Christensen
parent a18df2374a
commit 32ec44f726
4 changed files with 116 additions and 6 deletions

9
.idea/pgconn.iml Normal file
View File

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
@ -60,6 +61,9 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// SslPasswordCallback is a callback function to handle Auth callback for SSL Password
SslPasswordCallback SslPasswordCallbackHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
@ -132,6 +136,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
return network, address
}
// ParseConfig builds a *Config when sslpasswordcallback function is not provided
func ParseConfig(connString string) (*Config, error) {
return ParseConfigWithSslPasswordCallback(connString, nil)
}
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
@ -171,6 +180,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
@ -194,6 +204,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLCConfig.
// sslPasswordCallback function provide a callback function for sslpassword
//
// Other known differences with libpq:
//
@ -207,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback SslPasswordCallbackHandler) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
@ -278,6 +289,7 @@ func ParseConfig(connString string) (*Config, error) {
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
@ -326,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host)
tlsConfigs, err = configTLS(settings, host, sslPasswordCallback)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
@ -406,6 +418,7 @@ func parseEnvSettings() map[string]string {
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
@ -592,12 +605,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
func configTLS(settings map[string]string, thisHost string, sslPasswordCallback SslPasswordCallbackHandler) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
// Match libpq default behavior
if sslmode == "" {
@ -685,11 +699,43 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro
}
if sslcert != "" && sslkey != "" {
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
buf, err := ioutil.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
var pemKey []byte
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
if sslpassword == "" {
if sslPasswordCallback == nil {
return nil, fmt.Errorf("unable to find sslpassword: %w", err)
}
sslpassword = sslPasswordCallback()
}
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also provide warning for PKCS#1 needed?
if err != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", err)
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,
}
pemKey = pem.EncodeToMemory(&pemBytes)
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := ioutil.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
cert, err := tls.X509KeyPair(certfile, pemKey)
if err != nil {
return nil, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

View File

@ -64,6 +64,8 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)
type SslPasswordCallbackHandler func() (string)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)

View File

@ -1,6 +1,7 @@
package pgconn_test
import (
"bufio"
"bytes"
"compress/gzip"
"context"
@ -63,7 +64,59 @@ func TestConnectTLS(t *testing.T) {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
}
conn, err := pgconn.Connect(context.Background(), connString)
var conn *pgconn.PgConn
var err error
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
if isSslPasswrodEmpty {
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
require.Nil(t, err)
conn, err = pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
} else {
conn, err = pgconn.Connect(context.Background(), connString)
require.NoError(t, err)
}
if _, ok := conn.Conn().(*tls.Conn); !ok {
t.Error("not a TLS connection")
}
closeConn(t, conn)
}
func GetSslPassword() string {
readFile, err := os.Open("data.txt")
if err != nil {
fmt.Println(err)
}
fileScanner := bufio.NewScanner(readFile)
fileScanner.Split(bufio.ScanLines)
for fileScanner.Scan() {
line := fileScanner.Text()
if strings.HasPrefix(line, "sslpassword=") {
index := len("sslpassword=")
line := line[index:]
return line
}
}
return ""
}
func TestConnectTLSCallback(t *testing.T) {
t.Parallel()
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
}
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
require.Nil(t, err)
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
if _, ok := conn.Conn().(*tls.Conn); !ok {