mirror of https://github.com/jackc/pgx.git
Add test coverage for client SNI
parent
067771b2e6
commit
e3406d95f9
172
pgconn_test.go
172
pgconn_test.go
|
@ -2169,3 +2169,175 @@ func GetSSLPassword(ctx context.Context) string {
|
|||
connString := os.Getenv("PGX_SSL_PASSWORD")
|
||||
return connString
|
||||
}
|
||||
|
||||
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
|
||||
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
|
||||
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
|
||||
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
|
||||
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
|
||||
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
|
||||
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
|
||||
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
|
||||
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
|
||||
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
|
||||
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
|
||||
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
|
||||
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
|
||||
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
|
||||
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
|
||||
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
|
||||
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
|
||||
MlascjupnaptKX/wMA==
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
|
||||
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
|
||||
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
|
||||
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
|
||||
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
|
||||
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
|
||||
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
|
||||
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
|
||||
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
|
||||
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
|
||||
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
|
||||
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
|
||||
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
|
||||
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
|
||||
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
|
||||
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
|
||||
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
|
||||
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
|
||||
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
|
||||
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
|
||||
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
|
||||
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
|
||||
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
|
||||
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
|
||||
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
|
||||
z3w+CgS20UrbLIR1YXfqUXge1g==
|
||||
-----END TESTING KEY-----
|
||||
`)
|
||||
|
||||
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
|
||||
|
||||
func TestSNISupport(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
sni_param string
|
||||
sni_set bool
|
||||
}{
|
||||
{
|
||||
name: "SNI is passed by default",
|
||||
sni_param: "",
|
||||
sni_set: true,
|
||||
},
|
||||
{
|
||||
name: "SNI is passed when asked for",
|
||||
sni_param: "sslsni=1",
|
||||
sni_set: true,
|
||||
},
|
||||
{
|
||||
name: "SNI is not passed when disabled",
|
||||
sni_param: "sslsni=0",
|
||||
sni_set: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
serverSNINameChan := make(chan string, 1)
|
||||
defer close(serverErrChan)
|
||||
defer close(serverSNINameChan)
|
||||
|
||||
go func() {
|
||||
var sniHost string
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
|
||||
startupMessage, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
switch startupMessage.(type) {
|
||||
case *pgproto3.SSLRequest:
|
||||
_, err = conn.Write([]byte("S"))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
default:
|
||||
serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
srv := tls.Server(conn, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
sniHost = argHello.ServerName
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
if err := srv.Handshake(); err != nil {
|
||||
serverErrChan <- fmt.Errorf("handshake: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
|
||||
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
|
||||
|
||||
serverSNINameChan <- sniHost
|
||||
}()
|
||||
|
||||
port := strings.Split(ln.Addr().String(), ":")[1]
|
||||
connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
|
||||
_, err = pgconn.Connect(context.Background(), connStr)
|
||||
|
||||
select {
|
||||
case sniHost := <-serverSNINameChan:
|
||||
if tt.sni_set {
|
||||
require.Equal(t, sniHost, "localhost")
|
||||
} else {
|
||||
require.Equal(t, sniHost, "")
|
||||
}
|
||||
case err = <-serverErrChan:
|
||||
t.Fatalf("server failed with error: %+v", err)
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
t.Fatal("exceeded connection timeout without erroring out")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue