diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index a4fead7e..97a5844f 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -17,6 +17,7 @@ import ( "os" "sync" "sync/atomic" + "syscall" "time" "github.com/jackc/pgx/v5/internal/iobufpool" @@ -54,7 +55,8 @@ type Conn interface { // NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. type NetConn struct { - conn net.Conn + conn net.Conn + rawConn syscall.RawConn readQueue bufferQueue writeQueue bufferQueue @@ -72,10 +74,20 @@ type NetConn struct { closed int64 // 0 = not closed, 1 = closed } -func NewNetConn(conn net.Conn) *NetConn { - return &NetConn{ +func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { + nc := &NetConn{ conn: conn, } + + if !fakeNonBlockingIO { + if sc, ok := conn.(syscall.Conn); ok { + if rawConn, err := sc.SyscallConn(); err == nil { + nc.rawConn = rawConn + } + } + } + + return nc } // Read implements io.Reader. @@ -323,7 +335,11 @@ func (c *NetConn) isClosed() bool { } func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) + if c.rawConn == nil { + return c.fakeNonblockingWrite(b) + } else { + return c.realNonblockingWrite(b) + } } func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { @@ -351,8 +367,37 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { return c.conn.Write(b) } +func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Write(func(fd uintptr) (done bool) { + n, funcErr = syscall.Write(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. 
+ if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) + if c.rawConn == nil { + return c.fakeNonblockingRead(b) + } else { + return c.realNonblockingRead(b) + } } func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { @@ -380,6 +425,31 @@ func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { return c.conn.Read(b) } +func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { + var funcErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + n, funcErr = syscall.Read(int(fd), b) + return true + }) + if err == nil && funcErr != nil { + if errors.Is(funcErr, syscall.EWOULDBLOCK) { + err = ErrWouldBlock + } else { + err = funcErr + } + } + if err != nil { + // n may be -1 when an error occurs. + if n < 0 { + n = 0 + } + + return n, err + } + + return n, nil +} + // syscall.Conn is interface // TLSClient establishes a TLS connection as a client over conn using config. 
diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index f26ccea3..2db47039 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -67,31 +67,53 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { for _, tt := range []struct { - name string - makeConns func(t *testing.T) (local, remote net.Conn) - useTLS bool + name string + makeConns func(t *testing.T) (local, remote net.Conn) + useTLS bool + fakeNonBlockingIO bool }{ { - name: "Pipe", - makeConns: makePipeConns, - useTLS: false, + name: "Pipe", + makeConns: makePipeConns, + useTLS: false, + fakeNonBlockingIO: true, }, { - name: "TCP", - makeConns: makeTCPConns, - useTLS: false, + name: "TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: true, }, { - name: "TLS over TCP", - makeConns: makeTCPConns, - useTLS: true, + name: "TLS over TCP with Fake Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: true, + }, + { + name: "TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: false, + fakeNonBlockingIO: false, + }, + { + name: "TLS over TCP with Real Non-blocking IO", + makeConns: makeTCPConns, + useTLS: true, + fakeNonBlockingIO: false, }, } { t.Run(tt.name, func(t *testing.T) { local, remote := tt.makeConns(t) + // Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get + // garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never + // uses remote it may be garbage collected leading to the connection being closed. 
+ defer local.Close() + defer remote.Close() + var conn nbconn.Conn - netConn := nbconn.NewNetConn(local) + netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO) if tt.useTLS { cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) @@ -244,6 +266,60 @@ func TestCloseFlushesWriteBuffer(t *testing.T) { }) } +// This test exercises the non-blocking write path. Because writes are buffered it is difficult to trigger this with +// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing +// large values. +func TestInternalNonBlockingWrite(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + errChan := make(chan error, 1) + go func() { + remoteWriteBuf := make([]byte, deadlockSize) + _, err := remote.Write(remoteWriteBuf) + if err != nil { + errChan <- err + return + } + + readBuf := make([]byte, deadlockSize) + _, err = io.ReadFull(remote, readBuf) + errChan <- err + }() + + readBuf := make([]byte, deadlockSize) + _, err = conn.Read(readBuf) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) + }) +} + +func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { + const deadlockSize = 4 * 1024 * 1024 + + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + writeBuf := make([]byte, deadlockSize) + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, deadlockSize, n) + + err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err) + + err = conn.Flush() + require.Error(t, err) + }) +} + func TestNonBlockingRead(t *testing.T) { testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) 
diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 5a5fac82..f5f558b9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } - netConn = nbconn.NewNetConn(netConn) + netConn = nbconn.NewNetConn(netConn, false) pgConn.conn = netConn pgConn.contextWatcher = newContextWatcher(netConn)