diff --git a/.idea/pgconn.iml b/.idea/pgconn.iml new file mode 100644 index 00000000..5e764c4f --- /dev/null +++ b/.idea/pgconn.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/config.go b/config.go index 8fd7efbf..12a48288 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io" @@ -60,6 +61,9 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // SslPasswordCallback is a callback function to handle Auth callback for SSL Password + SslPasswordCallback SslPasswordCallbackHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -132,6 +136,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } +// ParseConfig builds a *Config when sslpasswordcallback function is not provided +func ParseConfig(connString string) (*Config, error) { + return ParseConfigWithSslPasswordCallback(connString, nil) +} + // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches // the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See @@ -171,6 +180,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLCERT // PGSSLKEY // PGSSLROOTCERT +// PGSSLPASSWORD // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS @@ -194,6 +204,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting // TLCConfig. +// sslPasswordCallback function provide a callback function for sslpassword // // Other known differences with libpq: // @@ -207,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // servicefile // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // part of the connection string. -func ParseConfig(connString string) (*Config, error) { +func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback SslPasswordCallbackHandler) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() @@ -278,6 +289,7 @@ func ParseConfig(connString string) (*Config, error) { "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslpassword": {}, "krbspn": {}, "krbsrvname": {}, "target_session_attrs": {}, @@ -326,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) { tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings, host) + tlsConfigs, err = configTLS(settings, host, sslPasswordCallback) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -406,6 +418,7 @@ func parseEnvSettings() map[string]string { "PGSSLKEY": "sslkey", "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", + "PGSSLPASSWORD": "sslpassword", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", @@ -592,12 +605,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { +func configTLS(settings map[string]string, thisHost string, sslPasswordCallback SslPasswordCallbackHandler) ([]*tls.Config, error) { host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] // Match libpq default behavior if sslmode == "" { @@ -685,11 +699,43 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro } if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + buf, err := ioutil.ReadFile(sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read sslkey: %w", err) + } + block, _ := pem.Decode(buf) + var pemKey []byte + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + if sslpassword == "" { + if sslPasswordCallback == nil { + return nil, fmt.Errorf("unable to find sslpassword: %w", err) + } + sslpassword = sslPasswordCallback() + } + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if err != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", err) + } + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + certfile, err := ioutil.ReadFile(sslcert) if err != nil { return nil, fmt.Errorf("unable to read cert: %w", err) } - + cert, err := tls.X509KeyPair(certfile, pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } tlsConfig.Certificates = []tls.Certificate{cert} } diff --git a/pgconn.go b/pgconn.go index 430f4367..67d6af38 100644 --- a/pgconn.go +++ b/pgconn.go @@ -64,6 +64,8 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +type SslPasswordCallbackHandler func() (string) + // Frontend used to receive messages from backend. type Frontend interface { Receive() (pgproto3.BackendMessage, error) diff --git a/pgconn_test.go b/pgconn_test.go index 32186fc6..d9adda99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bufio" "bytes" "compress/gzip" "context" @@ -63,7 +64,59 @@ func TestConnectTLS(t *testing.T) { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") } - conn, err := pgconn.Connect(context.Background(), connString) + var conn *pgconn.PgConn + var err error + + isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=") + + if isSslPasswrodEmpty { + config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + require.Nil(t, err) + + conn, err = pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + } else { + conn, err = pgconn.Connect(context.Background(), connString) + require.NoError(t, err) + } + + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + closeConn(t, conn) +} + +func GetSslPassword() string { + readFile, err := os.Open("data.txt") + if err != nil { + fmt.Println(err) + } + fileScanner := bufio.NewScanner(readFile) + fileScanner.Split(bufio.ScanLines) + for fileScanner.Scan() { + line := fileScanner.Text() + if strings.HasPrefix(line, "sslpassword=") { + index := len("sslpassword=") + line := line[index:] + return line + } + } + return "" +} + +func TestConnectTLSCallback(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + require.Nil(t, err) + + conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) if _, ok := conn.Conn().(*tls.Conn); !ok {