diff --git a/pgconn/config.go b/pgconn/config.go index e4a25cb8..c4ecf2c9 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io" @@ -24,6 +25,7 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error +type GetSSLPasswordFunc func(ctx context.Context) string // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A // manually initialized Config will cause ConnectConfig to panic. @@ -62,6 +64,13 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// ParseConfigOptions contains options that control how a config is built such as getsslpassword. +type ParseConfigOptions struct { + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // PQsetSSLKeyPassHook_OpenSSL. + GetSSLPassword GetSSLPasswordFunc +} + // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the TLSConfig field: // according to the tls.Config docs it must not be modified after creation. @@ -131,17 +140,17 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same -// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches -// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See -// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It +// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // -// # Example DSN -// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // -// # Example URL -// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be @@ -152,27 +161,28 @@ func NetworkAddress(host string, port uint16) (network, address string) { // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// # Example URL -// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed // via database URL or DSN: // -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGPASSFILE -// PGSERVICE -// PGSERVICEFILE -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT -// PGTARGETSESSIONATTRS +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGSSLPASSWORD +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // @@ -192,7 +202,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting -// TLCConfig. +// TLSConfig. // // Other known differences with libpq: // @@ -201,10 +211,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // In addition, ParseConfig accepts the following options: // -// servicefile -// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a -// part of the connection string. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. func ParseConfig(connString string) (*Config, error) { + var parseConfigOptions ParseConfigOptions + return ParseConfigWithOptions(connString, parseConfigOptions) +} + +// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard +// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to +// get the SSL password. +func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() @@ -272,6 +290,7 @@ func ParseConfig(connString string) (*Config, error) { "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslpassword": {}, "krbspn": {}, "krbsrvname": {}, "target_session_attrs": {}, @@ -319,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) { tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings, host) + tlsConfigs, err = configTLS(settings, host, options) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -399,6 +418,7 @@ func parseEnvSettings() map[string]string { "PGSSLKEY": "sslkey", "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", + "PGSSLPASSWORD": "sslpassword", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", @@ -585,12 +605,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { +func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] sslkey := settings["sslkey"] + sslpassword := settings["sslpassword"] // Match libpq default behavior if sslmode == "" { @@ -678,11 +699,53 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro } if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + buf, err := ioutil.ReadFile(sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read sslkey: %w", err) + } + block, _ := pem.Decode(buf) + var pemKey []byte + var decryptedKey []byte + var decryptedError error + // If PEM is encrypted, attempt to decrypt using pass phrase + if x509.IsEncryptedPEMBlock(block) { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if sslpassword != "" { + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + //if sslpassword not provided or has decryption error when use it + //try to find sslpassword with callback function + if sslpassword == "" || decryptedError != nil { + if parseConfigOptions.GetSSLPassword != nil { + sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) + } + if sslpassword == "" { + return nil, fmt.Errorf("unable to find sslpassword") + } + } + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + // Should we also provide warning for PKCS#1 needed? + if decryptedError != nil { + return nil, fmt.Errorf("unable to decrypt key: %w", err) + } + + pemBytes := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: decryptedKey, + } + pemKey = pem.EncodeToMemory(&pemBytes) + } else { + pemKey = pem.EncodeToMemory(block) + } + certfile, err := ioutil.ReadFile(sslcert) if err != nil { return nil, fmt.Errorf("unable to read cert: %w", err) } - + cert, err := tls.X509KeyPair(certfile, pemKey) + if err != nil { + return nil, fmt.Errorf("unable to load cert: %w", err) + } tlsConfig.Certificates = []tls.Certificate{cert} } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index dcd9dc85..461548a2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -98,6 +98,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be +// used to cancel a connect attempt. +func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { + config, err := ParseConfigWithOptions(connString, parseConfigOptions) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with // ParseConfig. ctx can be used to cancel a connect attempt. // diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 598d6629..2592f6a6 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -53,6 +53,35 @@ func TestConnect(t *testing.T) { } } +func TestConnectWithOptions(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure // connection. func TestConnectTLS(t *testing.T) { @@ -63,7 +92,15 @@ func TestConnectTLS(t *testing.T) { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") } - conn, err := pgconn.Connect(context.Background(), connString) + var conn *pgconn.PgConn + var err error + + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) + require.Nil(t, err) + + conn, err = pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) result := conn.ExecParams(context.Background(), `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read() @@ -2589,3 +2626,8 @@ func Example() { // 3 // SELECT 3 } + +func GetSSLPassword(ctx context.Context) string { + connString := os.Getenv("PGX_SSL_PASSWORD") + return connString +}