Merge remote-tracking branch 'pgconn/master' into v5-dev

This commit is contained in:
Jack Christensen 2022-08-19 17:36:29 -05:00
commit ef5655c563
3 changed files with 289 additions and 1 deletions

View File

@ -291,6 +291,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
@ -417,6 +418,7 @@ func parseEnvSettings() map[string]string {
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGTARGETSESSIONATTRS": "target_session_attrs",
@ -612,11 +614,15 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{}
@ -749,6 +755,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil

View File

@ -53,6 +53,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
@ -94,6 +95,7 @@ func TestParseConfig(t *testing.T) {
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
},
},
@ -111,6 +113,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
@ -133,6 +136,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
},
@ -148,6 +152,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
},
RuntimeParams: map[string]string{},
},
@ -519,6 +524,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "foo",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
@ -532,6 +538,7 @@ func TestParseConfig(t *testing.T) {
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "bar",
}},
&pgconn.FallbackConfig{
Host: "bar",
@ -543,6 +550,7 @@ func TestParseConfig(t *testing.T) {
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "baz",
}},
&pgconn.FallbackConfig{
Host: "baz",
@ -648,6 +656,82 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is set by default",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "sni.test",
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv4",
connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "1.1.1.1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv6",
connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "::1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (URL-style)",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (key/value style)",
connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
}
for i, tt := range tests {
@ -820,7 +904,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
}
}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI"}
savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
@ -884,6 +968,23 @@ func TestParseConfigEnvLibpq(t *testing.T) {
RuntimeParams: map[string]string{"application_name": "pgxtest"},
},
},
{
name: "SNI can be disabled via environment variable",
envvars: map[string]string{
"PGHOST": "test.foo",
"PGSSLMODE": "require",
"PGSSLSNI": "0",
},
config: &pgconn.Config{
User: osUserName,
Host: "test.foo",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
}
for i, tt := range tests {
@ -974,6 +1075,7 @@ application_name = spaced string
Port: 9999,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "abc.example.com",
},
RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{
@ -995,6 +1097,7 @@ application_name = spaced string
User: "defuser",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "def.example.com",
},
RuntimeParams: map[string]string{"application_name": "spaced string"},
Fallbacks: []*pgconn.FallbackConfig{

View File

@ -2631,3 +2631,175 @@ func GetSSLPassword(ctx context.Context) string {
connString := os.Getenv("PGX_SSL_PASSWORD")
return connString
}
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
func TestSNISupport(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sni_param string
sni_set bool
}{
{
name: "SNI is passed by default",
sni_param: "",
sni_set: true,
},
{
name: "SNI is passed when asked for",
sni_param: "sslsni=1",
sni_set: true,
},
{
name: "SNI is not passed when disabled",
sni_param: "sslsni=0",
sni_set: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
serverSNINameChan := make(chan string, 1)
defer close(serverErrChan)
defer close(serverSNINameChan)
go func() {
var sniHost string
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
if err != nil {
serverErrChan <- err
return
}
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
startupMessage, err := backend.ReceiveStartupMessage()
if err != nil {
serverErrChan <- err
return
}
switch startupMessage.(type) {
case *pgproto3.SSLRequest:
_, err = conn.Write([]byte("S"))
if err != nil {
serverErrChan <- err
return
}
default:
serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
return
}
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
if err != nil {
serverErrChan <- err
return
}
srv := tls.Server(conn, &tls.Config{
Certificates: []tls.Certificate{cert},
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
sniHost = argHello.ServerName
return nil, nil
},
})
defer srv.Close()
if err := srv.Handshake(); err != nil {
serverErrChan <- fmt.Errorf("handshake: %v", err)
return
}
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
serverSNINameChan <- sniHost
}()
port := strings.Split(ln.Addr().String(), ":")[1]
connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
_, err = pgconn.Connect(context.Background(), connStr)
select {
case sniHost := <-serverSNINameChan:
if tt.sni_set {
require.Equal(t, sniHost, "localhost")
} else {
require.Equal(t, sniHost, "")
}
case err = <-serverErrChan:
t.Fatalf("server failed with error: %+v", err)
case <-time.After(time.Millisecond * 100):
t.Fatal("exceeded connection timeout without erroring out")
}
})
}
}