mirror of https://github.com/jackc/pgx.git
427 lines
11 KiB
Go
427 lines
11 KiB
Go
package nbconn_test
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/internal/nbconn"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// Test keys generated with:
|
|
//
|
|
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
|
|
|
|
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
|
|
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
|
|
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
|
|
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
|
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
|
|
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
|
|
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
|
|
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
|
|
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
|
|
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
|
|
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
|
|
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
|
|
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
|
|
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
|
|
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
|
|
E7ZYmaKTMOhvkg==
|
|
-----END CERTIFICATE-----`)
|
|
|
|
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
|
|
// source code.
|
|
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
|
|
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
|
|
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
|
|
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
|
|
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
|
|
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
|
|
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
|
|
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
|
|
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
|
|
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
|
|
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
|
|
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
|
|
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
|
|
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
|
|
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
|
|
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
|
|
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
|
|
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
|
|
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
|
|
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
|
|
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
|
|
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
|
|
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
|
|
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
|
|
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
|
|
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
|
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
|
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
|
|
|
func testVariants(t *testing.T, f func(t *testing.T, local *nbconn.Conn, remote net.Conn)) {
|
|
for _, tt := range []struct {
|
|
name string
|
|
makeConns func(t *testing.T) (local, remote net.Conn)
|
|
useTLS bool
|
|
}{
|
|
{
|
|
name: "Pipe",
|
|
makeConns: makePipeConns,
|
|
useTLS: false,
|
|
},
|
|
{
|
|
name: "TCP",
|
|
makeConns: makeTCPConns,
|
|
useTLS: false,
|
|
},
|
|
{
|
|
name: "TLS over TCP",
|
|
makeConns: makeTCPConns,
|
|
useTLS: true,
|
|
},
|
|
} {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
local, remote := tt.makeConns(t)
|
|
conn := nbconn.New(local)
|
|
|
|
if tt.useTLS {
|
|
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
|
require.NoError(t, err)
|
|
|
|
remote = tls.Server(remote, &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
})
|
|
conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
|
}
|
|
|
|
f(t, conn, remote)
|
|
})
|
|
}
|
|
}
|
|
|
|
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
|
|
// useful for testing an exact sequence of reads and writes.
|
|
func makePipeConns(t *testing.T) (local, remote net.Conn) {
|
|
local, remote = net.Pipe()
|
|
t.Cleanup(func() {
|
|
local.Close()
|
|
remote.Close()
|
|
})
|
|
|
|
return local, remote
|
|
}
|
|
|
|
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
|
|
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
type acceptResultT struct {
|
|
conn net.Conn
|
|
err error
|
|
}
|
|
acceptChan := make(chan acceptResultT)
|
|
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
acceptChan <- acceptResultT{conn: conn, err: err}
|
|
}()
|
|
|
|
local, err = net.Dial("tcp", ln.Addr().String())
|
|
require.NoError(t, err)
|
|
|
|
acceptResult := <-acceptChan
|
|
require.NoError(t, acceptResult.err)
|
|
|
|
remote = acceptResult.conn
|
|
|
|
return local, remote
|
|
}
|
|
|
|
func TestWriteIsBuffered(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
// net.Pipe is synchronous so the Write would block if not buffered.
|
|
writeBuf := []byte("test")
|
|
n, err := conn.Write(writeBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 4, n)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
err := conn.Flush()
|
|
errChan <- err
|
|
}()
|
|
|
|
readBuf := make([]byte, len(writeBuf))
|
|
_, err = remote.Read(readBuf)
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, <-errChan)
|
|
})
|
|
}
|
|
|
|
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
err := conn.SetWriteDeadline(time.Now())
|
|
require.NoError(t, err)
|
|
|
|
writeBuf := []byte("test")
|
|
n, err := conn.Write(writeBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 4, n)
|
|
})
|
|
}
|
|
|
|
func TestReadFlushesWriteBuffer(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
writeBuf := []byte("test")
|
|
n, err := conn.Write(writeBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 4, n)
|
|
|
|
errChan := make(chan error, 2)
|
|
go func() {
|
|
readBuf := make([]byte, len(writeBuf))
|
|
_, err := remote.Read(readBuf)
|
|
errChan <- err
|
|
|
|
_, err = remote.Write([]byte("okay"))
|
|
errChan <- err
|
|
}()
|
|
|
|
readBuf := make([]byte, 4)
|
|
_, err = conn.Read(readBuf)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte("okay"), readBuf)
|
|
|
|
require.NoError(t, <-errChan)
|
|
require.NoError(t, <-errChan)
|
|
})
|
|
}
|
|
|
|
func TestCloseFlushesWriteBuffer(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
writeBuf := []byte("test")
|
|
n, err := conn.Write(writeBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 4, n)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
readBuf := make([]byte, len(writeBuf))
|
|
_, err := remote.Read(readBuf)
|
|
errChan <- err
|
|
}()
|
|
|
|
err = conn.Close()
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, <-errChan)
|
|
})
|
|
}
|
|
|
|
func TestNonBlockingRead(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
|
require.NoError(t, err)
|
|
|
|
buf := make([]byte, 4)
|
|
n, err := conn.Read(buf)
|
|
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
|
|
require.EqualValues(t, 0, n)
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
_, err := remote.Write([]byte("okay"))
|
|
errChan <- err
|
|
}()
|
|
|
|
err = conn.SetReadDeadline(time.Time{})
|
|
require.NoError(t, err)
|
|
|
|
n, err = conn.Read(buf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 4, n)
|
|
})
|
|
}
|
|
|
|
func TestReadPreviouslyBuffered(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
err := func() error {
|
|
_, err := remote.Write([]byte("alpha"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
readBuf := make([]byte, 4)
|
|
_, err = remote.Read(readBuf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
errChan <- err
|
|
}()
|
|
|
|
_, err := conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
|
|
// Because net.Pipe() is synchronous conn.Flust must buffer a read.
|
|
err = conn.Flush()
|
|
require.NoError(t, err)
|
|
|
|
readBuf := make([]byte, 5)
|
|
n, err := conn.Read(readBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 5, n)
|
|
require.Equal(t, []byte("alpha"), readBuf)
|
|
})
|
|
}
|
|
|
|
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
err := func() error {
|
|
_, err := remote.Write([]byte("alpha"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
readBuf := make([]byte, 4)
|
|
_, err = remote.Read(readBuf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
errChan <- err
|
|
}()
|
|
|
|
_, err := conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
|
|
// Because net.Pipe() is synchronous conn.Flust must buffer a read.
|
|
err = conn.Flush()
|
|
require.NoError(t, err)
|
|
|
|
readBuf := make([]byte, 2)
|
|
n, err := conn.Read(readBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 2, n)
|
|
require.Equal(t, []byte("al"), readBuf)
|
|
|
|
readBuf = make([]byte, 3)
|
|
n, err = conn.Read(readBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 3, n)
|
|
require.Equal(t, []byte("pha"), readBuf)
|
|
})
|
|
}
|
|
|
|
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
err := func() error {
|
|
_, err := remote.Write([]byte("alpha"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = remote.Write([]byte("beta"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
readBuf := make([]byte, 4)
|
|
_, err = remote.Read(readBuf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
errChan <- err
|
|
}()
|
|
|
|
_, err := conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
|
|
// Because net.Pipe() is synchronous conn.Flust must buffer a read.
|
|
err = conn.Flush()
|
|
require.NoError(t, err)
|
|
|
|
readBuf := make([]byte, 9)
|
|
n, err := conn.Read(readBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 9, n)
|
|
require.Equal(t, []byte("alphabeta"), readBuf)
|
|
})
|
|
}
|
|
|
|
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
|
|
testVariants(t, func(t *testing.T, conn *nbconn.Conn, remote net.Conn) {
|
|
|
|
flushCompleteChan := make(chan struct{})
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
err := func() error {
|
|
_, err := remote.Write([]byte("alpha"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
readBuf := make([]byte, 4)
|
|
_, err = remote.Read(readBuf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
<-flushCompleteChan
|
|
|
|
_, err = remote.Write([]byte("beta"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
errChan <- err
|
|
}()
|
|
|
|
_, err := conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
|
|
// Because net.Pipe() is synchronous conn.Flust must buffer a read.
|
|
err = conn.Flush()
|
|
require.NoError(t, err)
|
|
|
|
close(flushCompleteChan)
|
|
|
|
readBuf := make([]byte, 9)
|
|
|
|
n, err := io.ReadFull(conn, readBuf)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 9, n)
|
|
require.Equal(t, []byte("alphabeta"), readBuf)
|
|
|
|
err = <-errChan
|
|
require.NoError(t, err)
|
|
})
|
|
}
|