mirror of https://github.com/jackc/pgx.git
add pgmock tests
parent
9b15554c51
commit
924834b5b4
|
@ -51,7 +51,7 @@ type Config struct {
|
|||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
Sslnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
SSLnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
|
@ -389,7 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
config.Sslnegotiation = settings["sslnegotiation"]
|
||||
config.SSLnegotiation = settings["sslnegotiation"]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
|
|
|
@ -329,7 +329,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
|
|||
tlsConn net.Conn
|
||||
err error
|
||||
)
|
||||
if config.Sslnegotiation == "direct" {
|
||||
if config.SSLnegotiation == "direct" {
|
||||
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
|
||||
} else {
|
||||
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
expectDirectNego bool
|
||||
}{
|
||||
{
|
||||
name: "Default negotiation (postgres)",
|
||||
connString: "sslmode=require",
|
||||
expectDirectNego: false,
|
||||
},
|
||||
{
|
||||
name: "Direct negotiation",
|
||||
connString: "sslmode=require sslnegotiation=direct",
|
||||
expectDirectNego: true,
|
||||
},
|
||||
{
|
||||
name: "Explicit postgres negotiation",
|
||||
connString: "sslmode=require sslnegotiation=postgres",
|
||||
expectDirectNego: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
var directNegoObserved atomic.Bool
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrCh)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("accept error: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
firstByte := make([]byte, 1)
|
||||
_, err = conn.Read(firstByte)
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
|
||||
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
|
||||
directNegoObserved.Store(isDirect)
|
||||
|
||||
var tlsConn *tls.Conn
|
||||
|
||||
if !isDirect {
|
||||
// Handle standard PostgreSQL SSL negotiation
|
||||
// Read the rest of the SSL request message
|
||||
sslRequestRemainder := make([]byte, 7)
|
||||
_, err = io.ReadFull(conn, sslRequestRemainder)
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Send SSL acceptance response
|
||||
_, err = conn.Write([]byte("S"))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Setup TLS server without needing to reuse the first byte
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
tlsConn = tls.Server(conn, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else {
|
||||
// Handle direct TLS negotiation
|
||||
// Setup TLS server with the first byte already read
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use a wrapper to inject the first byte back into the TLS handshake
|
||||
bufConn := &prefixConn{
|
||||
Conn: conn,
|
||||
prefixData: firstByte,
|
||||
}
|
||||
|
||||
tlsConn = tls.Server(bufConn, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
}
|
||||
|
||||
// Complete TLS handshake
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
|
||||
return
|
||||
}
|
||||
defer tlsConn.Close()
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
|
||||
tt.connString, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, connStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close(ctx)
|
||||
|
||||
err = <-serverErrCh
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// prefixConn implements a net.Conn that prepends some data to the first Read
|
||||
type prefixConn struct {
|
||||
net.Conn
|
||||
prefixData []byte
|
||||
prefixConsumed bool
|
||||
}
|
||||
|
||||
func (c *prefixConn) Read(b []byte) (n int, err error) {
|
||||
if !c.prefixConsumed && len(c.prefixData) > 0 {
|
||||
n = copy(b, c.prefixData)
|
||||
c.prefixData = c.prefixData[n:]
|
||||
c.prefixConsumed = len(c.prefixData) == 0
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1920
|
||||
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
Loading…
Reference in New Issue