add pgmock tests

pull/2293/head
divyam234 2025-03-31 15:02:07 +02:00
parent 9b15554c51
commit 924834b5b4
No known key found for this signature in database
3 changed files with 171 additions and 3 deletions

View File

@ -51,7 +51,7 @@ type Config struct {
KerberosSpn string
Fallbacks []*FallbackConfig
Sslnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
SSLnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
@ -389,7 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
config.Sslnegotiation = settings["sslnegotiation"]
config.SSLnegotiation = settings["sslnegotiation"]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {

View File

@ -329,7 +329,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
tlsConn net.Conn
err error
)
if config.Sslnegotiation == "direct" {
if config.SSLnegotiation == "direct" {
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
} else {
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)

View File

@ -14,6 +14,7 @@ import (
"os"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
}
}
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
connString string
expectDirectNego bool
}{
{
name: "Default negotiation (postgres)",
connString: "sslmode=require",
expectDirectNego: false,
},
{
name: "Direct negotiation",
connString: "sslmode=require sslnegotiation=direct",
expectDirectNego: true,
},
{
name: "Explicit postgres negotiation",
connString: "sslmode=require sslnegotiation=postgres",
expectDirectNego: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
}
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
_, port, err := net.SplitHostPort(ln.Addr().String())
require.NoError(t, err)
var directNegoObserved atomic.Bool
serverErrCh := make(chan error, 1)
go func() {
defer close(serverErrCh)
conn, err := ln.Accept()
if err != nil {
serverErrCh <- fmt.Errorf("accept error: %w", err)
return
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second))
firstByte := make([]byte, 1)
_, err = conn.Read(firstByte)
if err != nil {
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
return
}
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
directNegoObserved.Store(isDirect)
var tlsConn *tls.Conn
if !isDirect {
// Handle standard PostgreSQL SSL negotiation
// Read the rest of the SSL request message
sslRequestRemainder := make([]byte, 7)
_, err = io.ReadFull(conn, sslRequestRemainder)
if err != nil {
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
return
}
// Send SSL acceptance response
_, err = conn.Write([]byte("S"))
if err != nil {
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
return
}
// Setup TLS server without needing to reuse the first byte
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
if err != nil {
serverErrCh <- fmt.Errorf("cert error: %w", err)
return
}
tlsConn = tls.Server(conn, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else {
// Handle direct TLS negotiation
// Setup TLS server with the first byte already read
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
if err != nil {
serverErrCh <- fmt.Errorf("cert error: %w", err)
return
}
// Use a wrapper to inject the first byte back into the TLS handshake
bufConn := &prefixConn{
Conn: conn,
prefixData: firstByte,
}
tlsConn = tls.Server(bufConn, &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
// Complete TLS handshake
if err := tlsConn.Handshake(); err != nil {
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
return
}
defer tlsConn.Close()
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
if err != nil {
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
return
}
}()
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
tt.connString, port)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
conn, err := pgconn.Connect(ctx, connStr)
require.NoError(t, err)
defer conn.Close(ctx)
err = <-serverErrCh
require.NoError(t, err)
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
})
}
}
// prefixConn implements a net.Conn that prepends some data to the first Read
type prefixConn struct {
net.Conn
prefixData []byte
prefixConsumed bool
}
func (c *prefixConn) Read(b []byte) (n int, err error) {
if !c.prefixConsumed && len(c.prefixData) > 0 {
n = copy(b, c.prefixData)
c.prefixData = c.prefixData[n:]
c.prefixConsumed = len(c.prefixData) == 0
return n, nil
}
return c.Conn.Read(b)
}
// https://github.com/jackc/pgx/issues/1920
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
t.Parallel()