From e3406d95f9be53211a295ee4b8f7208a27f48f05 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 16 Aug 2022 02:03:11 +0300 Subject: [PATCH] Add test coverage for client SNI --- pgconn_test.go | 172 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index a4f0ec63..302dffd5 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2169,3 +2169,175 @@ func GetSSLPassword(ctx context.Context) string { connString := os.Getenv("PGX_SSL_PASSWORD") return connString } + +var rsaCertPEM = `-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx +NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct +Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39 +tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d +9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp +0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv +MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E +FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o +6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2 +gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I +81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB +Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf +hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS +VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27 +MlascjupnaptKX/wMA== +-----END CERTIFICATE----- +` + +var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv +ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx +Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf +bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo +qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM +Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK +o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs +WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa +ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv +Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B +QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+ +QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC +CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods +bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3 +1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2 +SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6 +MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G +McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC +I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD +QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf +k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS +lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4 +TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr +5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi +UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T +z3w+CgS20UrbLIR1YXfqUXge1g== +-----END TESTING KEY----- +`) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + sni_param string + sni_set bool + }{ + { + name: "SNI is passed by default", + sni_param: "", + sni_set: true, + }, + { + name: "SNI is passed when asked for", + sni_param: "sslsni=1", + sni_set: true, + }, + { + name: "SNI is not passed when disabled", + sni_param: "sslsni=0", + sni_set: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + defer close(serverErrChan) + defer close(serverSNINameChan) + + go func() { + var sniHost string + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) + startupMessage, err := backend.ReceiveStartupMessage() + if err != nil { + serverErrChan <- err + return + } + + switch startupMessage.(type) { + case *pgproto3.SSLRequest: + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrChan <- err + return + } + default: + serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrChan <- err + return + } + + srv := tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + if err := srv.Handshake(); err != nil { + serverErrChan <- fmt.Errorf("handshake: %v", err) + return + } + + srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) + srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) + srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + + serverSNINameChan <- sniHost + }() + + port := strings.Split(ln.Addr().String(), ":")[1] + connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param) + _, err = pgconn.Connect(context.Background(), connStr) + + select { + case sniHost := <-serverSNINameChan: + if tt.sni_set { + require.Equal(t, sniHost, "localhost") + } else { + require.Equal(t, sniHost, "") + } + case err = <-serverErrChan: + t.Fatalf("server failed with error: %+v", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +}