From 8b5e8d9d89a3dbaf222a639254b77016d8e39102 Mon Sep 17 00:00:00 2001 From: Dmitry K Date: Sat, 18 Mar 2023 21:42:55 +0300 Subject: [PATCH] Fix Windows non-blocking I/O for CopyFrom Created based on discussion here: https://github.com/jackc/pgx/pull/1525#pullrequestreview-1344511991 Fixes https://github.com/jackc/pgx/issues/1552 --- copy_from.go | 12 ++++ internal/nbconn/nbconn.go | 3 + internal/nbconn/nbconn_real_non_block.go | 5 ++ .../nbconn/nbconn_real_non_block_windows.go | 59 ++++++++++++++----- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/copy_from.go b/copy_from.go index a2c227fd..37351848 100644 --- a/copy_from.go +++ b/copy_from.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/jackc/pgx/v5/internal/nbconn" "io" "github.com/jackc/pgx/v5/internal/pgio" @@ -134,6 +135,17 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { r, w := io.Pipe() doneChan := make(chan struct{}) + if realNbConn, ok := ct.conn.pgConn.Conn().(*nbconn.NetConn); ok { + if err := realNbConn.SetBlockingMode(false); err != nil { + return 0, fmt.Errorf("cannot set socket non-blocking mode: %w", err) + } + + defer func() { + // TODO: Deal with it + _ = realNbConn.SetBlockingMode(true) + }() + } + go func() { defer close(doneChan) diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 7a38383f..0c03ba7a 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -96,6 +96,9 @@ type NetConn struct { writeDeadlineLock sync.Mutex writeDeadline time.Time + + // Indicates that underlying socket connection mode set to be non-blocking + isNonBlocking bool } func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { diff --git a/internal/nbconn/nbconn_real_non_block.go b/internal/nbconn/nbconn_real_non_block.go index e93372f2..3a6ed54a 100644 --- a/internal/nbconn/nbconn_real_non_block.go +++ b/internal/nbconn/nbconn_real_non_block.go @@ -79,3 +79,8 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { return n, nil } + +func (c *NetConn) SetBlockingMode(blocking bool) error { + // for UNIX do nothing + return nil +} diff --git a/internal/nbconn/nbconn_real_non_block_windows.go b/internal/nbconn/nbconn_real_non_block_windows.go index ed1662e0..39144ea4 100644 --- a/internal/nbconn/nbconn_real_non_block_windows.go +++ b/internal/nbconn/nbconn_real_non_block_windows.go @@ -43,10 +43,12 @@ func setSockMode(fd uintptr, mode sockMode) error { func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { if c.nonblockWriteFunc == nil { c.nonblockWriteFunc = func(fd uintptr) (done bool) { - // Make sock non-blocking - if err := setSockMode(fd, sockModeNonBlocking); err != nil { - c.nonblockWriteErr = err - return true + if !c.isNonBlocking { + // Make sock non-blocking + if err := setSockMode(fd, sockModeNonBlocking); err != nil { + c.nonblockWriteErr = err + return true + } } var written uint32 @@ -56,10 +58,12 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil) c.nonblockWriteN = int(written) - // Make sock blocking again - if err := setSockMode(fd, sockModeBlocking); err != nil { - c.nonblockWriteErr = err - return true + if !c.isNonBlocking { + // Make sock blocking again + if err := setSockMode(fd, sockModeBlocking); err != nil { + c.nonblockWriteErr = err + return true + } } return true @@ -94,10 +98,12 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { if c.nonblockReadFunc == nil { c.nonblockReadFunc = func(fd uintptr) (done bool) { - // Make sock non-blocking - if err := setSockMode(fd, sockModeNonBlocking); err != nil { - c.nonblockReadErr = err - return true + if !c.isNonBlocking { + // Make sock non-blocking + if err := setSockMode(fd, sockModeNonBlocking); err != nil { + c.nonblockReadErr = err + return true + } } var read uint32 @@ -108,10 +114,12 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil) c.nonblockReadN = int(read) - // Make sock blocking again - if err := setSockMode(fd, sockModeBlocking); err != nil { - c.nonblockReadErr = err - return true + if !c.isNonBlocking { + // Make sock blocking again + if err := setSockMode(fd, sockModeBlocking); err != nil { + c.nonblockReadErr = err + return true + } } return true @@ -147,3 +155,22 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { return n, nil } + +func (c *NetConn) SetBlockingMode(blocking bool) error { + mode := sockModeNonBlocking + if blocking { + mode = sockModeBlocking + } + + var err error + + c.rawConn.Control(func(fd uintptr) { + err = setSockMode(fd, mode) + }) + + if err == nil { + c.isNonBlocking = !blocking + } + + return err +}