mirror of https://github.com/jackc/pgx.git
Add true non-blocking IO
parent
7dd26a34a1
commit
60ecdda02e
|
@ -17,6 +17,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
|
@ -54,7 +55,8 @@ type Conn interface {
|
||||||
|
|
||||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||||
type NetConn struct {
|
type NetConn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
rawConn syscall.RawConn
|
||||||
|
|
||||||
readQueue bufferQueue
|
readQueue bufferQueue
|
||||||
writeQueue bufferQueue
|
writeQueue bufferQueue
|
||||||
|
@ -72,10 +74,20 @@ type NetConn struct {
|
||||||
closed int64 // 0 = not closed, 1 = closed
|
closed int64 // 0 = not closed, 1 = closed
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetConn(conn net.Conn) *NetConn {
|
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||||
return &NetConn{
|
nc := &NetConn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !fakeNonBlockingIO {
|
||||||
|
if sc, ok := conn.(syscall.Conn); ok {
|
||||||
|
if rawConn, err := sc.SyscallConn(); err == nil {
|
||||||
|
nc.rawConn = rawConn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read implements io.Reader.
|
// Read implements io.Reader.
|
||||||
|
@ -323,7 +335,11 @@ func (c *NetConn) isClosed() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||||
return c.fakeNonblockingWrite(b)
|
if c.rawConn == nil {
|
||||||
|
return c.fakeNonblockingWrite(b)
|
||||||
|
} else {
|
||||||
|
return c.realNonblockingWrite(b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
@ -351,8 +367,37 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
||||||
return c.conn.Write(b)
|
return c.conn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||||
|
var funcErr error
|
||||||
|
err = c.rawConn.Write(func(fd uintptr) (done bool) {
|
||||||
|
n, funcErr = syscall.Write(int(fd), b)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err == nil && funcErr != nil {
|
||||||
|
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
} else {
|
||||||
|
err = funcErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// n may be -1 when an error occurs.
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||||
return c.fakeNonblockingRead(b)
|
if c.rawConn == nil {
|
||||||
|
return c.fakeNonblockingRead(b)
|
||||||
|
} else {
|
||||||
|
return c.realNonblockingRead(b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
@ -380,6 +425,31 @@ func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
|
||||||
return c.conn.Read(b)
|
return c.conn.Read(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||||
|
var funcErr error
|
||||||
|
err = c.rawConn.Read(func(fd uintptr) (done bool) {
|
||||||
|
n, funcErr = syscall.Read(int(fd), b)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err == nil && funcErr != nil {
|
||||||
|
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
|
||||||
|
err = ErrWouldBlock
|
||||||
|
} else {
|
||||||
|
err = funcErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// n may be -1 when an error occurs.
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
// syscall.Conn is interface
|
// syscall.Conn is interface
|
||||||
|
|
||||||
// TLSClient establishes a TLS connection as a client over conn using config.
|
// TLSClient establishes a TLS connection as a client over conn using config.
|
||||||
|
|
|
@ -67,31 +67,53 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
||||||
|
|
||||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
makeConns func(t *testing.T) (local, remote net.Conn)
|
makeConns func(t *testing.T) (local, remote net.Conn)
|
||||||
useTLS bool
|
useTLS bool
|
||||||
|
fakeNonBlockingIO bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Pipe",
|
name: "Pipe",
|
||||||
makeConns: makePipeConns,
|
makeConns: makePipeConns,
|
||||||
useTLS: false,
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "TCP",
|
name: "TCP with Fake Non-blocking IO",
|
||||||
makeConns: makeTCPConns,
|
makeConns: makeTCPConns,
|
||||||
useTLS: false,
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "TLS over TCP",
|
name: "TLS over TCP with Fake Non-blocking IO",
|
||||||
makeConns: makeTCPConns,
|
makeConns: makeTCPConns,
|
||||||
useTLS: true,
|
useTLS: true,
|
||||||
|
fakeNonBlockingIO: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP with Real Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: false,
|
||||||
|
fakeNonBlockingIO: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TLS over TCP with Real Non-blocking IO",
|
||||||
|
makeConns: makeTCPConns,
|
||||||
|
useTLS: true,
|
||||||
|
fakeNonBlockingIO: false,
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
local, remote := tt.makeConns(t)
|
local, remote := tt.makeConns(t)
|
||||||
|
|
||||||
|
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
|
||||||
|
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
|
||||||
|
// uses remote it may be garbage collected leading to the connection being closed.
|
||||||
|
defer local.Close()
|
||||||
|
defer remote.Close()
|
||||||
|
|
||||||
var conn nbconn.Conn
|
var conn nbconn.Conn
|
||||||
netConn := nbconn.NewNetConn(local)
|
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
|
||||||
|
|
||||||
if tt.useTLS {
|
if tt.useTLS {
|
||||||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||||
|
@ -244,6 +266,60 @@ func TestCloseFlushesWriteBuffer(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
|
||||||
|
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
|
||||||
|
// large values.
|
||||||
|
func TestInternalNonBlockingWrite(t *testing.T) {
|
||||||
|
const deadlockSize = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := make([]byte, deadlockSize)
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, deadlockSize, n)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
remoteWriteBuf := make([]byte, deadlockSize)
|
||||||
|
_, err := remote.Write(remoteWriteBuf)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
readBuf := make([]byte, deadlockSize)
|
||||||
|
_, err = io.ReadFull(remote, readBuf)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuf := make([]byte, deadlockSize)
|
||||||
|
_, err = conn.Read(readBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, <-errChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
|
||||||
|
const deadlockSize = 4 * 1024 * 1024
|
||||||
|
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
writeBuf := make([]byte, deadlockSize)
|
||||||
|
n, err := conn.Write(writeBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, deadlockSize, n)
|
||||||
|
|
||||||
|
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = conn.Flush()
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestNonBlockingRead(t *testing.T) {
|
func TestNonBlockingRead(t *testing.T) {
|
||||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
|
||||||
|
|
|
@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
}
|
}
|
||||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||||
}
|
}
|
||||||
netConn = nbconn.NewNetConn(netConn)
|
netConn = nbconn.NewNetConn(netConn, false)
|
||||||
|
|
||||||
pgConn.conn = netConn
|
pgConn.conn = netConn
|
||||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||||
|
|
Loading…
Reference in New Issue