EC-2198 change for sslpassword

pull/1281/head
yun.xu 2022-07-19 10:36:38 -04:00 committed by Jack Christensen
parent 7402796e02
commit cdd2cc4124
3 changed files with 66 additions and 56 deletions

View File

@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
}
block, _ := pem.Decode(buf)
var pemKey []byte
var decryptedKey []byte
var decryptedError error
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
if sslpassword == "" {
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if(sslpassword != ""){
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
}
//if sslpassword not provided or has decryption error when use it
//try to find sslpassword with callback function
if (sslpassword == "" || decryptedError!= nil) {
if(parseConfigOptions.GetSSLPassword != nil){
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
}else{
}
if(sslpassword == ""){
return nil, fmt.Errorf("unable to find sslpassword")
}
}
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also provide warning for PKCS#1 needed?
if err != nil {
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", err)
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,

View File

@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using the environment
// and connString (in URL or DSN format) and ParseConfigOptions
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
if err != nil {
return nil, err
}
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
// ParseConfig. ctx can be used to cancel a connect attempt.
//

View File

@ -1,7 +1,6 @@
package pgconn_test
import (
"bufio"
"bytes"
"compress/gzip"
"context"
@ -54,6 +53,35 @@ func TestConnect(t *testing.T) {
}
}
func TestConnectWithOption(t *testing.T) {
tests := []struct {
name string
env string
}{
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
connString := os.Getenv(tt.env)
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", tt.env)
}
var sslOptions pgconn.ParseConfigOptions
sslOptions.GetSSLPassword = GetSSLPassword
conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
require.NoError(t, err)
closeConn(t, conn)
})
}
}
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
// connection.
func TestConnectTLS(t *testing.T) {
@ -67,58 +95,14 @@ func TestConnectTLS(t *testing.T) {
var conn *pgconn.PgConn
var err error
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
if isSslPasswrodEmpty {
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
require.Nil(t, err)
conn, err = pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
} else {
conn, err = pgconn.Connect(context.Background(), connString)
require.NoError(t, err)
}
if _, ok := conn.Conn().(*tls.Conn); !ok {
t.Error("not a TLS connection")
}
closeConn(t, conn)
}
func GetSslPassword() string {
readFile, err := os.Open("data.txt")
if err != nil {
fmt.Println(err)
}
fileScanner := bufio.NewScanner(readFile)
fileScanner.Split(bufio.ScanLines)
for fileScanner.Scan() {
line := fileScanner.Text()
if strings.HasPrefix(line, "sslpassword=") {
index := len("sslpassword=")
line := line[index:]
return line
}
}
return ""
}
func TestConnectTLSCallback(t *testing.T) {
t.Parallel()
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
}
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
var sslOptions pgconn.ParseConfigOptions
sslOptions.GetSSLPassword = GetSSLPassword
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
require.Nil(t, err)
conn, err := pgconn.ConnectConfig(context.Background(), config)
conn, err = pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
if _, ok := conn.Conn().(*tls.Conn); !ok {
t.Error("not a TLS connection")
}
@ -2180,3 +2164,8 @@ func Example() {
// 3
// SELECT 3
}
func GetSSLPassword(ctx context.Context) string {
connString := os.Getenv("PGX_SSL_PASSWORD")
return connString
}