From cdd2cc41244843d1aaaf47efa89ff1f8dce6e3c1 Mon Sep 17 00:00:00 2001 From: "yun.xu" Date: Tue, 19 Jul 2022 10:36:38 -0400 Subject: [PATCH] EC-2198 change for sslpassword --- config.go | 21 ++++++++---- pgconn.go | 12 +++++++ pgconn_test.go | 89 ++++++++++++++++++++++---------------------------- 3 files changed, 66 insertions(+), 56 deletions(-) diff --git a/config.go b/config.go index fa9e3801..2e038304 100644 --- a/config.go +++ b/config.go @@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P } block, _ := pem.Decode(buf) var pemKey []byte + var decryptedKey []byte + var decryptedError error // If PEM is encrypted, attempt to decrypt using pass phrase if x509.IsEncryptedPEMBlock(block) { - if sslpassword == "" { + // Attempt decryption with pass phrase + // NOTE: only supports RSA (PKCS#1) + if(sslpassword != ""){ + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) + } + //if sslpassword not provided or has decryption error when use it + //try to find sslpassword with callback function + if (sslpassword == "" || decryptedError!= nil) { if(parseConfigOptions.GetSSLPassword != nil){ sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) - }else{ + } + if(sslpassword == ""){ return nil, fmt.Errorf("unable to find sslpassword") } } - // Attempt decryption with pass phrase - // NOTE: only supports RSA (PKCS#1) - decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword)) + decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? - if err != nil { + if decryptedError != nil { return nil, fmt.Errorf("unable to decrypt key: %w", err) } + pemBytes := pem.Block{ Type: "RSA PRIVATE KEY", Bytes: decryptedKey, diff --git a/pgconn.go b/pgconn.go index 430f4367..f582f5b8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } +// Connect establishes a connection to a PostgreSQL server using the environment +// and connString (in URL or DSN format) and ParseConfigOptions +// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { + config, err := ParseConfigWithOptions(connString, parseConfigOptions) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with // ParseConfig. ctx can be used to cancel a connect attempt. // diff --git a/pgconn_test.go b/pgconn_test.go index d9adda99..9a52abf6 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,7 +1,6 @@ package pgconn_test import ( - "bufio" "bytes" "compress/gzip" "context" @@ -54,6 +53,35 @@ func TestConnect(t *testing.T) { } } +func TestConnectWithOption(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions) + require.NoError(t, err) + + closeConn(t, conn) + }) + } +} + // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure // connection. func TestConnectTLS(t *testing.T) { @@ -67,58 +95,14 @@ func TestConnectTLS(t *testing.T) { var conn *pgconn.PgConn var err error - isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=") - - if isSslPasswrodEmpty { - config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) - require.Nil(t, err) - - conn, err = pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - } else { - conn, err = pgconn.Connect(context.Background(), connString) - require.NoError(t, err) - } - - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } - - closeConn(t, conn) -} - -func GetSslPassword() string { - readFile, err := os.Open("data.txt") - if err != nil { - fmt.Println(err) - } - fileScanner := bufio.NewScanner(readFile) - fileScanner.Split(bufio.ScanLines) - for fileScanner.Scan() { - line := fileScanner.Text() - if strings.HasPrefix(line, "sslpassword=") { - index := len("sslpassword=") - line := line[index:] - return line - } - } - return "" -} - -func TestConnectTLSCallback(t *testing.T) { - t.Parallel() - - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } - - config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword) + var sslOptions pgconn.ParseConfigOptions + sslOptions.GetSSLPassword = GetSSLPassword + config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) require.Nil(t, err) - conn, err := pgconn.ConnectConfig(context.Background(), config) + conn, err = pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) - + if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } @@ -2180,3 +2164,8 @@ func Example() { // 3 // SELECT 3 } + +func GetSSLPassword(ctx context.Context) string { + connString := os.Getenv("PGX_SSL_PASSWORD") + return connString +} \ No newline at end of file