mirror of https://github.com/jackc/pgx.git
EC-2198 change for sslpassword
parent
7402796e02
commit
cdd2cc4124
21
config.go
21
config.go
|
@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||||
}
|
}
|
||||||
block, _ := pem.Decode(buf)
|
block, _ := pem.Decode(buf)
|
||||||
var pemKey []byte
|
var pemKey []byte
|
||||||
|
var decryptedKey []byte
|
||||||
|
var decryptedError error
|
||||||
// If PEM is encrypted, attempt to decrypt using pass phrase
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||||
if x509.IsEncryptedPEMBlock(block) {
|
if x509.IsEncryptedPEMBlock(block) {
|
||||||
if sslpassword == "" {
|
// Attempt decryption with pass phrase
|
||||||
|
// NOTE: only supports RSA (PKCS#1)
|
||||||
|
if(sslpassword != ""){
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
}
|
||||||
|
//if sslpassword not provided or has decryption error when use it
|
||||||
|
//try to find sslpassword with callback function
|
||||||
|
if (sslpassword == "" || decryptedError!= nil) {
|
||||||
if(parseConfigOptions.GetSSLPassword != nil){
|
if(parseConfigOptions.GetSSLPassword != nil){
|
||||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||||
}else{
|
}
|
||||||
|
if(sslpassword == ""){
|
||||||
return nil, fmt.Errorf("unable to find sslpassword")
|
return nil, fmt.Errorf("unable to find sslpassword")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Attempt decryption with pass phrase
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
// NOTE: only supports RSA (PKCS#1)
|
|
||||||
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
|
|
||||||
// Should we also provide warning for PKCS#1 needed?
|
// Should we also provide warning for PKCS#1 needed?
|
||||||
if err != nil {
|
if decryptedError != nil {
|
||||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pemBytes := pem.Block{
|
pemBytes := pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: decryptedKey,
|
Bytes: decryptedKey,
|
||||||
|
|
12
pgconn.go
12
pgconn.go
|
@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
|
||||||
return ConnectConfig(ctx, config)
|
return ConnectConfig(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connect establishes a connection to a PostgreSQL server using the environment
|
||||||
|
// and connString (in URL or DSN format) and ParseConfigOptions
|
||||||
|
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
|
||||||
|
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
|
||||||
|
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ConnectConfig(ctx, config)
|
||||||
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
|
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
|
||||||
// ParseConfig. ctx can be used to cancel a connect attempt.
|
// ParseConfig. ctx can be used to cancel a connect attempt.
|
||||||
//
|
//
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package pgconn_test
|
package pgconn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
|
@ -54,6 +53,35 @@ func TestConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithOption(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
env string
|
||||||
|
}{
|
||||||
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||||
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||||
|
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
||||||
|
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
||||||
|
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
connString := os.Getenv(tt.env)
|
||||||
|
if connString == "" {
|
||||||
|
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
||||||
|
}
|
||||||
|
var sslOptions pgconn.ParseConfigOptions
|
||||||
|
sslOptions.GetSSLPassword = GetSSLPassword
|
||||||
|
conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
closeConn(t, conn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
|
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
|
||||||
// connection.
|
// connection.
|
||||||
func TestConnectTLS(t *testing.T) {
|
func TestConnectTLS(t *testing.T) {
|
||||||
|
@ -67,58 +95,14 @@ func TestConnectTLS(t *testing.T) {
|
||||||
var conn *pgconn.PgConn
|
var conn *pgconn.PgConn
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
|
var sslOptions pgconn.ParseConfigOptions
|
||||||
|
sslOptions.GetSSLPassword = GetSSLPassword
|
||||||
if isSslPasswrodEmpty {
|
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
|
||||||
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
conn, err = pgconn.ConnectConfig(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
conn, err = pgconn.Connect(context.Background(), connString)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
|
||||||
t.Error("not a TLS connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
closeConn(t, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSslPassword() string {
|
|
||||||
readFile, err := os.Open("data.txt")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
fileScanner := bufio.NewScanner(readFile)
|
|
||||||
fileScanner.Split(bufio.ScanLines)
|
|
||||||
for fileScanner.Scan() {
|
|
||||||
line := fileScanner.Text()
|
|
||||||
if strings.HasPrefix(line, "sslpassword=") {
|
|
||||||
index := len("sslpassword=")
|
|
||||||
line := line[index:]
|
|
||||||
return line
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectTLSCallback(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
|
|
||||||
if connString == "" {
|
|
||||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
conn, err = pgconn.ConnectConfig(context.Background(), config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
||||||
t.Error("not a TLS connection")
|
t.Error("not a TLS connection")
|
||||||
}
|
}
|
||||||
|
@ -2180,3 +2164,8 @@ func Example() {
|
||||||
// 3
|
// 3
|
||||||
// SELECT 3
|
// SELECT 3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetSSLPassword(ctx context.Context) string {
|
||||||
|
connString := os.Getenv("PGX_SSL_PASSWORD")
|
||||||
|
return connString
|
||||||
|
}
|
Loading…
Reference in New Issue