diff --git a/copy_from.go b/copy_from.go index e87bb94a..a2c227fd 100644 --- a/copy_from.go +++ b/copy_from.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "github.com/jackc/pgx/v5/internal/nbconn" "io" "github.com/jackc/pgx/v5/internal/pgio" @@ -132,17 +131,6 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } - if realNbConn, ok := ct.conn.pgConn.Conn().(*nbconn.NetConn); ok { - if err := realNbConn.SetBlockingMode(false); err != nil { - return 0, fmt.Errorf("cannot set socket non-blocking mode: %w", err) - } - - defer func() { - // TODO: Deal with it - _ = realNbConn.SetBlockingMode(true) - }() - } - r, w := io.Pipe() doneChan := make(chan struct{}) diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 534ec605..a0dd5d4c 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -13,6 +13,7 @@ package nbconn import ( "crypto/tls" "errors" + "fmt" "net" "os" "sync" @@ -97,8 +98,8 @@ type NetConn struct { writeDeadlineLock sync.Mutex writeDeadline time.Time - // Indicates that underlying socket connection mode explicitly set to be non-blocking - isNonBlocking bool + // nbOperCnt Tracks how many operations performing simultaneously + nbOperCnt int32 } func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { @@ -160,6 +161,18 @@ func (c *NetConn) Read(b []byte) (n int, err error) { var readN int if readNonblocking { + if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil { + err = fmt.Errorf("cannot set socket to non-blocking mode: %w", setSockModeErr) + } + + if err != nil { + return n, err + } + + defer func() { + _ = c.SetBlockingMode(true) + }() + readN, err = c.nonblockingRead(b[n:]) } else { readN, err = c.conn.Read(b[n:]) @@ -284,6 +297,14 @@ func (c *NetConn) flush() error { var stopChan chan struct{} var errChan chan error + if err := c.SetBlockingMode(false); err != nil { + return fmt.Errorf("cannot set socket to non-blocking mode: %w", err) + } + + defer func() { + _ = c.SetBlockingMode(true) + }() + defer func() { if stopChan != nil { select { @@ -327,6 +348,14 @@ func (c *NetConn) flush() error { } func (c *NetConn) BufferReadUntilBlock() error { + if err := c.SetBlockingMode(false); err != nil { + return fmt.Errorf("cannot set socket to non-blocking mode: %w", err) + } + + defer func() { + _ = c.SetBlockingMode(true) + }() + for { buf := iobufpool.Get(8 * 1024) n, err := c.nonblockingRead(*buf) diff --git a/internal/nbconn/nbconn_real_non_block_windows.go b/internal/nbconn/nbconn_real_non_block_windows.go index 8ee2b5ae..eb7c3f35 100644 --- a/internal/nbconn/nbconn_real_non_block_windows.go +++ b/internal/nbconn/nbconn_real_non_block_windows.go @@ -6,6 +6,7 @@ import ( "errors" "golang.org/x/sys/windows" "io" + "sync/atomic" "syscall" "unsafe" ) @@ -43,14 +44,6 @@ func setSockMode(fd uintptr, mode sockMode) error { func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { if c.nonblockWriteFunc == nil { c.nonblockWriteFunc = func(fd uintptr) (done bool) { - if !c.isNonBlocking { - // Make sock non-blocking - if err := setSockMode(fd, sockModeNonBlocking); err != nil { - c.nonblockWriteErr = err - return true - } - } - var written uint32 var buf syscall.WSABuf buf.Buf = &c.nonblockWriteBuf[0] @@ -58,14 +51,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil) c.nonblockWriteN = int(written) - if !c.isNonBlocking { - // Make sock blocking again - if err := setSockMode(fd, sockModeBlocking); err != nil { - c.nonblockWriteErr = err - return true - } - } - return true } } @@ -98,14 +83,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { if c.nonblockReadFunc == nil { c.nonblockReadFunc = func(fd uintptr) (done bool) { - if !c.isNonBlocking { - // Make sock non-blocking - if err := setSockMode(fd, sockModeNonBlocking); err != nil { - c.nonblockReadErr = err - return true - } - } - var read uint32 var flags uint32 var buf syscall.WSABuf @@ -114,14 +91,6 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil) c.nonblockReadN = int(read) - if !c.isNonBlocking { - // Make sock blocking again - if err := setSockMode(fd, sockModeBlocking); err != nil { - c.nonblockReadErr = err - return true - } - } - return true } } @@ -157,22 +126,52 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { } func (c *NetConn) SetBlockingMode(blocking bool) error { + // Fake non-blocking I/O is ignored + if c.rawConn == nil { + return nil + } + + if blocking { + // No ready to exit from non-blocking mode, there are pending non-blocking operations + if atomic.AddInt32(&c.nbOperCnt, -1) > 0 { + return nil + } + } else { + // Socket is already in non-blocking state + if atomic.AddInt32(&c.nbOperCnt, 1) > 1 { + return nil + } + + //fmt.Println("socket reverting to blocking mode") + } + mode := sockModeNonBlocking if blocking { mode = sockModeBlocking } - var err error + var ctrlErr, err error - if ctrlErr := c.rawConn.Control(func(fd uintptr) { + ctrlErr = c.rawConn.Control(func(fd uintptr) { err = setSockMode(fd, mode) - }); ctrlErr != nil { - return ctrlErr + }) + + if ctrlErr != nil || err != nil { + // Revert counters inc/dec in case of error + if blocking { + atomic.AddInt32(&c.nbOperCnt, 1) + } else { + atomic.AddInt32(&c.nbOperCnt, -1) + } + + if ctrlErr != nil { + return ctrlErr + } + + if err != nil { + return err + } } - if err == nil { - c.isNonBlocking = !blocking - } - - return err + return nil }