mirror of https://github.com/jackc/pgx.git
add pgmock tests
parent
9b15554c51
commit
924834b5b4
|
@ -51,7 +51,7 @@ type Config struct {
|
||||||
KerberosSpn string
|
KerberosSpn string
|
||||||
Fallbacks []*FallbackConfig
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
Sslnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
SSLnegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||||
|
|
||||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||||
|
@ -389,7 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||||
config.Port = fallbacks[0].Port
|
config.Port = fallbacks[0].Port
|
||||||
config.TLSConfig = fallbacks[0].TLSConfig
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
config.Fallbacks = fallbacks[1:]
|
config.Fallbacks = fallbacks[1:]
|
||||||
config.Sslnegotiation = settings["sslnegotiation"]
|
config.SSLnegotiation = settings["sslnegotiation"]
|
||||||
|
|
||||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -329,7 +329,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
|
||||||
tlsConn net.Conn
|
tlsConn net.Conn
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if config.Sslnegotiation == "direct" {
|
if config.SSLnegotiation == "direct" {
|
||||||
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
|
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
|
||||||
} else {
|
} else {
|
||||||
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
|
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connString string
|
||||||
|
expectDirectNego bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Default negotiation (postgres)",
|
||||||
|
connString: "sslmode=require",
|
||||||
|
expectDirectNego: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Direct negotiation",
|
||||||
|
connString: "sslmode=require sslnegotiation=direct",
|
||||||
|
expectDirectNego: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Explicit postgres negotiation",
|
||||||
|
connString: "sslmode=require sslnegotiation=postgres",
|
||||||
|
expectDirectNego: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
script := &pgmock.Script{
|
||||||
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var directNegoObserved atomic.Bool
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(serverErrCh)
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("accept error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
firstByte := make([]byte, 1)
|
||||||
|
_, err = conn.Read(firstByte)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
|
||||||
|
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
|
||||||
|
directNegoObserved.Store(isDirect)
|
||||||
|
|
||||||
|
var tlsConn *tls.Conn
|
||||||
|
|
||||||
|
if !isDirect {
|
||||||
|
// Handle standard PostgreSQL SSL negotiation
|
||||||
|
// Read the rest of the SSL request message
|
||||||
|
sslRequestRemainder := make([]byte, 7)
|
||||||
|
_, err = io.ReadFull(conn, sslRequestRemainder)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send SSL acceptance response
|
||||||
|
_, err = conn.Write([]byte("S"))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup TLS server without needing to reuse the first byte
|
||||||
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn = tls.Server(conn, &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Handle direct TLS negotiation
|
||||||
|
// Setup TLS server with the first byte already read
|
||||||
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a wrapper to inject the first byte back into the TLS handshake
|
||||||
|
bufConn := &prefixConn{
|
||||||
|
Conn: conn,
|
||||||
|
prefixData: firstByte,
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn = tls.Server(bufConn, &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete TLS handshake
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
|
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
|
||||||
|
tt.connString, port)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, connStr)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
err = <-serverErrCh
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixConn implements a net.Conn that prepends some data to the first Read
|
||||||
|
type prefixConn struct {
|
||||||
|
net.Conn
|
||||||
|
prefixData []byte
|
||||||
|
prefixConsumed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *prefixConn) Read(b []byte) (n int, err error) {
|
||||||
|
if !c.prefixConsumed && len(c.prefixData) > 0 {
|
||||||
|
n = copy(b, c.prefixData)
|
||||||
|
c.prefixData = c.prefixData[n:]
|
||||||
|
c.prefixConsumed = len(c.prefixData) == 0
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return c.Conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1920
|
// https://github.com/jackc/pgx/issues/1920
|
||||||
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
Loading…
Reference in New Issue