Try to make windows non-blocking I/O

pull/1540/head
Dmitry K 2023-02-26 00:00:37 +03:00 committed by Jack Christensen
parent c09ddaf440
commit 087b8b2ba8
3 changed files with 62 additions and 2 deletions

1
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/jackc/puddle/v2 v2.2.0
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.6.0
golang.org/x/sys v0.5.0
golang.org/x/text v0.7.0
)

2
go.sum
View File

@ -33,6 +33,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -4,20 +4,64 @@ package nbconn
import (
"errors"
"golang.org/x/sys/windows"
"io"
"syscall"
"unsafe"
)
var dll = syscall.MustLoadDLL("ws2_32.dll")
// int ioctlsocket(
//
// [in] SOCKET s,
// [in] long cmd,
// [in, out] u_long *argp
//
// );
var ioctlsocket = dll.MustFindProc("ioctlsocket")
type sockMode int
const (
FIONBIO int = 0x8004667e
sockModeBlocking sockMode = 0
sockModeNonBlocking sockMode = 1
)
func setSockMode(fd uintptr, mode sockMode) error {
res, _, err := ioctlsocket.Call(fd, uintptr(FIONBIO), uintptr(unsafe.Pointer(&mode)))
// Upon successful completion, the ioctlsocket returns zero.
if res != 0 && err != nil {
return err
}
return nil
}
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
if c.nonblockWriteFunc == nil {
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
var written uint32
var buf syscall.WSABuf
buf.Buf = &c.nonblockWriteBuf[0]
buf.Len = uint32(len(c.nonblockWriteBuf))
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
c.nonblockWriteN = int(written)
// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
return true
}
}
@ -29,7 +73,7 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
n = c.nonblockWriteN
c.nonblockWriteBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockWriteErr != nil {
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
if errors.Is(c.nonblockWriteErr, windows.WSAEWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockWriteErr
@ -50,6 +94,12 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
if c.nonblockReadFunc == nil {
c.nonblockReadFunc = func(fd uintptr) (done bool) {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
var read uint32
var flags uint32
var buf syscall.WSABuf
@ -57,6 +107,13 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
buf.Len = uint32(len(c.nonblockReadBuf))
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
c.nonblockReadN = int(read)
// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockWriteErr = err
return true
}
return true
}
}
@ -68,7 +125,7 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
n = c.nonblockReadN
c.nonblockReadBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockReadErr != nil {
if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) {
if errors.Is(c.nonblockReadErr, windows.WSAEWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockReadErr