Set socket to non-blocking mode in `Read`, `Flush` and `BufferReadUntilBlock` operations

pull/1557/head
Dmitry K 2023-03-19 03:18:56 +03:00 committed by Jack Christensen
parent 3db7d1774e
commit b2b4fbcf57
3 changed files with 71 additions and 55 deletions

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"fmt"
"github.com/jackc/pgx/v5/internal/nbconn"
"io"
"github.com/jackc/pgx/v5/internal/pgio"
@ -132,17 +131,6 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
}
if realNbConn, ok := ct.conn.pgConn.Conn().(*nbconn.NetConn); ok {
if err := realNbConn.SetBlockingMode(false); err != nil {
return 0, fmt.Errorf("cannot set socket non-blocking mode: %w", err)
}
defer func() {
// TODO: Deal with it
_ = realNbConn.SetBlockingMode(true)
}()
}
r, w := io.Pipe()
doneChan := make(chan struct{})

View File

@ -13,6 +13,7 @@ package nbconn
import (
"crypto/tls"
"errors"
"fmt"
"net"
"os"
"sync"
@ -97,8 +98,8 @@ type NetConn struct {
writeDeadlineLock sync.Mutex
writeDeadline time.Time
// Indicates that underlying socket connection mode explicitly set to be non-blocking
isNonBlocking bool
// nbOperCnt Tracks how many operations performing simultaneously
nbOperCnt int32
}
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
@ -160,6 +161,18 @@ func (c *NetConn) Read(b []byte) (n int, err error) {
var readN int
if readNonblocking {
if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil {
err = fmt.Errorf("cannot set socket to non-blocking mode: %w", setSockModeErr)
}
if err != nil {
return n, err
}
defer func() {
_ = c.SetBlockingMode(true)
}()
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
@ -284,6 +297,14 @@ func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error
if err := c.SetBlockingMode(false); err != nil {
return fmt.Errorf("cannot set socket to non-blocking mode: %w", err)
}
defer func() {
_ = c.SetBlockingMode(true)
}()
defer func() {
if stopChan != nil {
select {
@ -327,6 +348,14 @@ func (c *NetConn) flush() error {
}
func (c *NetConn) BufferReadUntilBlock() error {
if err := c.SetBlockingMode(false); err != nil {
return fmt.Errorf("cannot set socket to non-blocking mode: %w", err)
}
defer func() {
_ = c.SetBlockingMode(true)
}()
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(*buf)

View File

@ -6,6 +6,7 @@ import (
"errors"
"golang.org/x/sys/windows"
"io"
"sync/atomic"
"syscall"
"unsafe"
)
@ -43,14 +44,6 @@ func setSockMode(fd uintptr, mode sockMode) error {
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
if c.nonblockWriteFunc == nil {
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
if !c.isNonBlocking {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
}
var written uint32
var buf syscall.WSABuf
buf.Buf = &c.nonblockWriteBuf[0]
@ -58,14 +51,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
c.nonblockWriteN = int(written)
if !c.isNonBlocking {
// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
}
return true
}
}
@ -98,14 +83,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
if c.nonblockReadFunc == nil {
c.nonblockReadFunc = func(fd uintptr) (done bool) {
if !c.isNonBlocking {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockReadErr = err
return true
}
}
var read uint32
var flags uint32
var buf syscall.WSABuf
@ -114,14 +91,6 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
c.nonblockReadN = int(read)
if !c.isNonBlocking {
// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockReadErr = err
return true
}
}
return true
}
}
@ -157,22 +126,52 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
}
func (c *NetConn) SetBlockingMode(blocking bool) error {
// Fake non-blocking I/O is ignored
if c.rawConn == nil {
return nil
}
if blocking {
// No ready to exit from non-blocking mode, there are pending non-blocking operations
if atomic.AddInt32(&c.nbOperCnt, -1) > 0 {
return nil
}
} else {
// Socket is already in non-blocking state
if atomic.AddInt32(&c.nbOperCnt, 1) > 1 {
return nil
}
//fmt.Println("socket reverting to blocking mode")
}
mode := sockModeNonBlocking
if blocking {
mode = sockModeBlocking
}
var err error
var ctrlErr, err error
if ctrlErr := c.rawConn.Control(func(fd uintptr) {
ctrlErr = c.rawConn.Control(func(fd uintptr) {
err = setSockMode(fd, mode)
}); ctrlErr != nil {
return ctrlErr
})
if ctrlErr != nil || err != nil {
// Revert counters inc/dec in case of error
if blocking {
atomic.AddInt32(&c.nbOperCnt, 1)
} else {
atomic.AddInt32(&c.nbOperCnt, -1)
}
if ctrlErr != nil {
return ctrlErr
}
if err != nil {
return err
}
}
if err == nil {
c.isNonBlocking = !blocking
}
return err
return nil
}