mirror of https://github.com/jackc/pgx.git
Add true non-blocking IO
parent
7dd26a34a1
commit
60ecdda02e
|
@ -17,6 +17,7 @@ import (
|
|||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
|
@ -54,7 +55,8 @@ type Conn interface {
|
|||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
type NetConn struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
rawConn syscall.RawConn
|
||||
|
||||
readQueue bufferQueue
|
||||
writeQueue bufferQueue
|
||||
|
@ -72,10 +74,20 @@ type NetConn struct {
|
|||
closed int64 // 0 = not closed, 1 = closed
|
||||
}
|
||||
|
||||
func NewNetConn(conn net.Conn) *NetConn {
|
||||
return &NetConn{
|
||||
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||
nc := &NetConn{
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
if !fakeNonBlockingIO {
|
||||
if sc, ok := conn.(syscall.Conn); ok {
|
||||
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||
nc.rawConn = rawConn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
|
@ -323,7 +335,11 @@ func (c *NetConn) isClosed() bool {
|
|||
}
|
||||
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
} else {
|
||||
return c.realNonblockingWrite(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||
|
@ -351,8 +367,37 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
|||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||
var funcErr error
|
||||
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||
n, funcErr = syscall.Write(int(fd), b)
|
||||
return true
|
||||
})
|
||||
if err == nil && funcErr != nil {
|
||||
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = funcErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
return c.fakeNonblockingRead(b)
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingRead(b)
|
||||
} else {
|
||||
return c.realNonblockingRead(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||
|
@ -380,6 +425,31 @@ func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
|||
return c.conn.Read(b)
|
||||
}
|
||||
|
||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
var funcErr error
|
||||
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||
n, funcErr = syscall.Read(int(fd), b)
|
||||
return true
|
||||
})
|
||||
if err == nil && funcErr != nil {
|
||||
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||
err = ErrWouldBlock
|
||||
} else {
|
||||
err = funcErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// n may be -1 when an error occurs.
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// syscall.Conn is interface
|
||||
|
||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||
|
|
|
@ -67,31 +67,53 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
|||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
useTLS bool
|
||||
name string
|
||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||
useTLS bool
|
||||
fakeNonBlockingIO bool
|
||||
}{
|
||||
{
|
||||
name: "Pipe",
|
||||
makeConns: makePipeConns,
|
||||
useTLS: false,
|
||||
name: "Pipe",
|
||||
makeConns: makePipeConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
name: "TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
name: "TLS over TCP with Fake Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: true,
|
||||
},
|
||||
{
|
||||
name: "TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: false,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
{
|
||||
name: "TLS over TCP with Real Non-blocking IO",
|
||||
makeConns: makeTCPConns,
|
||||
useTLS: true,
|
||||
fakeNonBlockingIO: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
local, remote := tt.makeConns(t)
|
||||
|
||||
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||
// uses remote it may be garbage collected leading to the connection being closed.
|
||||
defer local.Close()
|
||||
defer remote.Close()
|
||||
|
||||
var conn nbconn.Conn
|
||||
netConn := nbconn.NewNetConn(local)
|
||||
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||
|
||||
if tt.useTLS {
|
||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
|
@ -244,6 +266,60 @@ func TestCloseFlushesWriteBuffer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||
// large values.
|
||||
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
remoteWriteBuf := make([]byte, deadlockSize)
|
||||
_, err := remote.Write(remoteWriteBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, <-errChan)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||
const deadlockSize = 4 * 1024 * 1024
|
||||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Flush()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||
|
|
|
@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
netConn = nbconn.NewNetConn(netConn)
|
||||
netConn = nbconn.NewNetConn(netConn, false)
|
||||
|
||||
pgConn.conn = netConn
|
||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||
|
|
Loading…
Reference in New Issue