mirror of https://github.com/jackc/pgx.git
EC-2198 change for sslpassword
parent
7402796e02
commit
cdd2cc4124
21
config.go
21
config.go
|
@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||
if x509.IsEncryptedPEMBlock(block) {
|
||||
if sslpassword == "" {
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
if(sslpassword != ""){
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
}
|
||||
//if sslpassword not provided or has decryption error when use it
|
||||
//try to find sslpassword with callback function
|
||||
if (sslpassword == "" || decryptedError!= nil) {
|
||||
if(parseConfigOptions.GetSSLPassword != nil){
|
||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||
}else{
|
||||
}
|
||||
if(sslpassword == ""){
|
||||
return nil, fmt.Errorf("unable to find sslpassword")
|
||||
}
|
||||
}
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
// Should we also provide warning for PKCS#1 needed?
|
||||
if err != nil {
|
||||
if decryptedError != nil {
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||
}
|
||||
|
||||
pemBytes := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: decryptedKey,
|
||||
|
|
12
pgconn.go
12
pgconn.go
|
@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
|
|||
return ConnectConfig(ctx, config)
|
||||
}
|
||||
|
||||
// Connect establishes a connection to a PostgreSQL server using the environment
|
||||
// and connString (in URL or DSN format) and ParseConfigOptions
|
||||
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
|
||||
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
|
||||
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ConnectConfig(ctx, config)
|
||||
}
|
||||
|
||||
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
|
||||
// ParseConfig. ctx can be used to cancel a connect attempt.
|
||||
//
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package pgconn_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
|
@ -54,6 +53,35 @@ func TestConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithOption(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
env string
|
||||
}{
|
||||
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
||||
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
||||
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
connString := os.Getenv(tt.env)
|
||||
if connString == "" {
|
||||
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
||||
}
|
||||
var sslOptions pgconn.ParseConfigOptions
|
||||
sslOptions.GetSSLPassword = GetSSLPassword
|
||||
conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
|
||||
require.NoError(t, err)
|
||||
|
||||
closeConn(t, conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
|
||||
// connection.
|
||||
func TestConnectTLS(t *testing.T) {
|
||||
|
@ -67,58 +95,14 @@ func TestConnectTLS(t *testing.T) {
|
|||
var conn *pgconn.PgConn
|
||||
var err error
|
||||
|
||||
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
|
||||
|
||||
if isSslPasswrodEmpty {
|
||||
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
||||
require.Nil(t, err)
|
||||
|
||||
conn, err = pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
conn, err = pgconn.Connect(context.Background(), connString)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
||||
t.Error("not a TLS connection")
|
||||
}
|
||||
|
||||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
func GetSslPassword() string {
|
||||
readFile, err := os.Open("data.txt")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileScanner := bufio.NewScanner(readFile)
|
||||
fileScanner.Split(bufio.ScanLines)
|
||||
for fileScanner.Scan() {
|
||||
line := fileScanner.Text()
|
||||
if strings.HasPrefix(line, "sslpassword=") {
|
||||
index := len("sslpassword=")
|
||||
line := line[index:]
|
||||
return line
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestConnectTLSCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
|
||||
if connString == "" {
|
||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
||||
}
|
||||
|
||||
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
||||
var sslOptions pgconn.ParseConfigOptions
|
||||
sslOptions.GetSSLPassword = GetSSLPassword
|
||||
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
|
||||
require.Nil(t, err)
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
conn, err = pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
||||
t.Error("not a TLS connection")
|
||||
}
|
||||
|
@ -2180,3 +2164,8 @@ func Example() {
|
|||
// 3
|
||||
// SELECT 3
|
||||
}
|
||||
|
||||
func GetSSLPassword(ctx context.Context) string {
|
||||
connString := os.Getenv("PGX_SSL_PASSWORD")
|
||||
return connString
|
||||
}
|
Loading…
Reference in New Issue