diff --git a/pgconn/config.go b/pgconn/config.go index 926cb980..f12dfc1e 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -51,7 +51,7 @@ type Config struct { KerberosSpn string Fallbacks []*FallbackConfig - Sslnegotiation string // sslnegotiation=postgres or sslnegotiation=direct + SSLnegotiation string // sslnegotiation=postgres or sslnegotiation=direct // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next @@ -389,7 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] - config.Sslnegotiation = settings["sslnegotiation"] + config.SSLnegotiation = settings["sslnegotiation"] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index fb2dd191..26a590ca 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -329,7 +329,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo tlsConn net.Conn err error ) - if config.Sslnegotiation == "direct" { + if config.SSLnegotiation == "direct" { tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) } else { tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b2d2f7f7..f1930479 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -14,6 +14,7 @@ import ( "os" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) { } } +func TestConnectWithDirectSSLNegotiation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connString string + expectDirectNego bool + }{ + { + name: "Default negotiation (postgres)", + connString: "sslmode=require", + expectDirectNego: false, + }, + { + name: "Direct negotiation", + connString: "sslmode=require sslnegotiation=direct", + expectDirectNego: true, + }, + { + name: "Explicit postgres negotiation", + connString: "sslmode=require sslnegotiation=postgres", + expectDirectNego: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + _, port, err := net.SplitHostPort(ln.Addr().String()) + require.NoError(t, err) + + var directNegoObserved atomic.Bool + + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + + conn, err := ln.Accept() + if err != nil { + serverErrCh <- fmt.Errorf("accept error: %w", err) + return + } + defer conn.Close() + + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + firstByte := make([]byte, 1) + _, err = conn.Read(firstByte) + if err != nil { + serverErrCh <- fmt.Errorf("read first byte error: %w", err) + return + } + + // Check if TLS Client Hello (direct) or PostgreSQL SSLRequest + isDirect := firstByte[0] >= 20 && firstByte[0] <= 23 + directNegoObserved.Store(isDirect) + + var tlsConn *tls.Conn + + if !isDirect { + // Handle standard PostgreSQL SSL negotiation + // Read the rest of the SSL request message + sslRequestRemainder := make([]byte, 7) + _, err = io.ReadFull(conn, sslRequestRemainder) + if err != nil { + serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err) + return + } + + // Send SSL acceptance response + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err) + return + } + + // Setup TLS server without needing to reuse the first byte + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + tlsConn = tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } else { + // Handle direct TLS negotiation + // Setup TLS server with the first byte already read + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + // Use a wrapper to inject the first byte back into the TLS handshake + bufConn := &prefixConn{ + Conn: conn, + prefixData: firstByte, + } + + tlsConn = tls.Server(bufConn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } + + // Complete TLS handshake + if err := tlsConn.Handshake(); err != nil { + serverErrCh <- fmt.Errorf("TLS handshake error: %w", err) + return + } + defer tlsConn.Close() + + err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn)) + if err != nil { + serverErrCh <- fmt.Errorf("pgmock run error: %w", err) + return + } + }() + + connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1", + tt.connString, port) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, connStr) + + require.NoError(t, err) + + defer conn.Close(ctx) + + err = <-serverErrCh + require.NoError(t, err) + + require.Equal(t, tt.expectDirectNego, directNegoObserved.Load()) + }) + } +} + +// prefixConn implements a net.Conn that prepends some data to the first Read +type prefixConn struct { + net.Conn + prefixData []byte + prefixConsumed bool +} + +func (c *prefixConn) Read(b []byte) (n int, err error) { + if !c.prefixConsumed && len(c.prefixData) > 0 { + n = copy(b, c.prefixData) + c.prefixData = c.prefixData[n:] + c.prefixConsumed = len(c.prefixData) == 0 + return n, nil + } + return c.Conn.Read(b) +} + // https://github.com/jackc/pgx/issues/1920 func TestFatalErrorReceivedInPipelineMode(t *testing.T) { t.Parallel()