Remove nbconn

The non-blocking IO system was designed to solve three problems:

1. Deadlock that can occur when both sides of a connection are blocked
   writing because all buffers between are full.
2. The inability to use a write deadline with a TLS.Conn without killing
   the connection.
3. Efficiently check if a connection has been closed before writing.
   This reduces the cases where the application doesn't know if a query
   that does a INSERT/UPDATE/DELETE was actually sent to the server or
   not.

However, the nbconn package is extraordinarily complex, has been a
source of very tricky bugs, and has OS specific code paths. It also does
not work at all with underlying net.Conn implementations that do not
have platform specific non-blocking IO syscall support and do not
properly implement deadlines. In particular, this is the case with
golang.org/x/crypto/ssh.

I believe the deadlock problem can be solved with a combination of a
goroutine for CopyFrom like v4 used and a watchdog for regular queries
that uses time.AfterFunc.

The write deadline problem actually should be ignorable. We check for
context cancellation before sending a query and the actual Write should
be almost instant as long as the underlying connection is not blocked.
(We should only have to wait until it is accepted by the OS, not until
it is fully sent.)

Efficiently checking if a connection has been closed is probably the
hardest to solve without non-blocking reads. However, the existing code
only solves part of the problem. It can detect a closed or broken
connection the OS knows about, but it won't actually detect other types
of broken connections such as a network interruption. This is currently
implemented in CheckConn and called automatically when checking a
connection out of the pool that has been idle for over one second. I
think that changing CheckConn to a very short deadline read and changing
the pool to do an actual Ping would be an acceptable solution.

Remove nbconn and non-blocking code. This does not leave the system in
an entirely working state. In particular, CopyFrom is broken, deadlocks
can occur for extremely large queries or batches, and PgConn.CheckConn
is now a `select 1` ping. These will be resolved in subsequent commits.
pull/1644/head
Jack Christensen 2023-05-27 07:06:22 -05:00 committed by Jack Christensen
parent 9cfdd21f1c
commit 4410fc0a65
8 changed files with 28 additions and 1587 deletions

View File

@ -13,7 +13,6 @@ import (
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/require"
@ -1120,7 +1119,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
}
type queryRecorder struct {
conn nbconn.Conn
conn net.Conn
writeBuf []byte
readCount int
}
@ -1136,14 +1135,6 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
return qr.conn.Write(b)
}
func (qr *queryRecorder) BufferReadUntilBlock() error {
return qr.conn.BufferReadUntilBlock()
}
func (qr *queryRecorder) Flush() error {
return qr.conn.Flush()
}
func (qr *queryRecorder) Close() error {
return qr.conn.Close()
}

View File

@ -1,70 +0,0 @@
package nbconn
import (
"sync"
)
const minBufferQueueLen = 8
type bufferQueue struct {
lock sync.Mutex
queue []*[]byte
r, w int
}
func (bq *bufferQueue) pushBack(buf *[]byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}
func (bq *bufferQueue) pushFront(buf *[]byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}
func (bq *bufferQueue) popFront() *[]byte {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.r == bq.w {
return nil
}
buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++
if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([]*[]byte, minBufferQueueLen)
}
}
return buf
}
func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}
newQueue := make([]*[]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}

View File

@ -1,550 +0,0 @@
// Package nbconn implements a non-blocking net.Conn wrapper.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
package nbconn
import (
"crypto/tls"
"errors"
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var errClosed = errors.New("closed")
var ErrWouldBlock = new(wouldBlockError)
const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond
const minNonblockingReadWaitDuration = time.Microsecond
const maxNonblockingReadWaitDuration = 100 * time.Millisecond
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
// mode.
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
// ignore all future calls.
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
type wouldBlockError struct{}
func (*wouldBlockError) Error() string {
return "would block"
}
func (*wouldBlockError) Timeout() bool { return true }
func (*wouldBlockError) Temporary() bool { return true }
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
// the underlying connection.
type Conn interface {
net.Conn
// Flush flushes any buffered writes.
Flush() error
// BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block.
BufferReadUntilBlock() error
}
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
type NetConn struct {
// 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit
// architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and
// https://github.com/jackc/pgx/issues/1307. Only access with atomics
closed int64 // 0 = not closed, 1 = closed
conn net.Conn
rawConn syscall.RawConn
readQueue bufferQueue
writeQueue bufferQueue
readFlushLock sync.Mutex
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
nonblockWriteFunc func(fd uintptr) (done bool)
nonblockWriteBuf []byte
nonblockWriteErr error
nonblockWriteN int
// non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
nonblockReadFunc func(fd uintptr) (done bool)
nonblockReadBuf []byte
nonblockReadErr error
nonblockReadN int
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
fakeNonBlockingShortReadCount int
fakeNonblockingReadWaitDuration time.Duration
writeDeadlineLock sync.Mutex
writeDeadline time.Time
// The following fields are used in nbconn_real_non_block_windows
// nbOperMu Used to prevent concurrent SetBlockingMode calls
nbOperMu sync.Mutex
// nbOperCnt Tracks how many operations performing simultaneously
nbOperCnt int
}
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
nc := &NetConn{
conn: conn,
fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration,
}
if !fakeNonBlockingIO {
if sc, ok := conn.(syscall.Conn); ok {
if rawConn, err := sc.SyscallConn(); err == nil {
nc.rawConn = rawConn
}
}
}
return nc
}
// Read implements io.Reader.
func (c *NetConn) Read(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return 0, err
}
for n < len(b) {
buf := c.readQueue.popFront()
if buf == nil {
break
}
copiedN := copy(b[n:], *buf)
if copiedN < len(*buf) {
*buf = (*buf)[copiedN:]
c.readQueue.pushFront(buf)
} else {
iobufpool.Put(buf)
}
n += copiedN
}
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
if n > 0 {
return n, nil
}
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
var readN int
if readNonblocking {
if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil {
return n, setSockModeErr
}
defer func() {
_ = c.SetBlockingMode(true)
}()
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
}
n += readN
return n, err
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
// closed. Call Flush to actually write to the underlying connection.
func (c *NetConn) Write(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
buf := iobufpool.Get(len(b))
copy(*buf, b)
c.writeQueue.pushBack(buf)
return len(b), nil
}
func (c *NetConn) Close() (err error) {
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
if !swapped {
return errClosed
}
defer func() {
closeErr := c.conn.Close()
if err == nil {
err = closeErr
}
}()
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return err
}
return nil
}
func (c *NetConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *NetConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *NetConn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *NetConn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if c.readDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.readDeadline = t
return nil
}
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.conn.SetReadDeadline(t)
}
func (c *NetConn) SetWriteDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
if c.writeDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.writeDeadline = t
return nil
}
c.writeDeadline = t
return c.conn.SetWriteDeadline(t)
}
func (c *NetConn) Flush() error {
if c.isClosed() {
return errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
return c.flush()
}
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error
if err := c.SetBlockingMode(false); err != nil {
return err
}
defer func() {
_ = c.SetBlockingMode(true)
}()
defer func() {
if stopChan != nil {
select {
case stopChan <- struct{}{}:
case <-errChan:
}
}
}()
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := *buf
for len(remainingBuf) > 0 {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
*buf = (*buf)[:len(remainingBuf)]
copy(*buf, remainingBuf)
c.writeQueue.pushFront(buf)
return err
}
// Writing was blocked. Reading might unblock it.
if stopChan == nil {
stopChan, errChan = c.bufferNonblockingRead()
}
select {
case err := <-errChan:
stopChan = nil
return err
default:
}
}
}
iobufpool.Put(buf)
}
return nil
}
func (c *NetConn) BufferReadUntilBlock() error {
if err := c.SetBlockingMode(false); err != nil {
return err
}
defer func() {
_ = c.SetBlockingMode(true)
}()
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(*buf)
if n > 0 {
*buf = (*buf)[:n]
c.readQueue.pushBack(buf)
} else if n == 0 {
iobufpool.Put(buf)
}
if err != nil {
if errors.Is(err, ErrWouldBlock) {
return nil
} else {
return err
}
}
}
}
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
stopChan = make(chan struct{})
errChan = make(chan error, 1)
go func() {
for {
err := c.BufferReadUntilBlock()
if err != nil {
errChan <- err
return
}
select {
case <-stopChan:
return
default:
}
}
}()
return stopChan, errChan
}
func (c *NetConn) isClosed() bool {
closed := atomic.LoadInt64(&c.closed)
return closed == 1
}
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
return c.realNonblockingWrite(b)
}
}
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWriteWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.conn.SetWriteDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetWriteDeadline(c.writeDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Write(b)
}
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {
return c.realNonblockingRead(b)
}
}
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
// The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are
// already in Go or the OS's receive buffer.
if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration {
b = b[:1]
}
startTime := time.Now()
deadline := startTime.Add(c.fakeNonblockingReadWaitDuration)
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
err = c.conn.SetReadDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// If the read was successful and the wait duration is not already the minimum
if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration {
endTime := time.Now()
if n > 0 && c.fakeNonBlockingShortReadCount < 5 {
c.fakeNonBlockingShortReadCount++
}
// The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that
// a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive
// buffer.
proposedWait := endTime.Sub(startTime) * 2
if proposedWait < minNonblockingReadWaitDuration {
proposedWait = minNonblockingReadWaitDuration
}
if proposedWait < c.fakeNonblockingReadWaitDuration {
c.fakeNonblockingReadWaitDuration = proposedWait
}
}
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetReadDeadline(c.readDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Read(b)
}
// syscall.Conn is interface
// TLSClient establishes a TLS connection as a client over conn using config.
//
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
// *TLSConn is returned.
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
tc := tls.Client(conn, config)
err := tc.Handshake()
if err != nil {
return nil, err
}
// Ensure last written part of Handshake is actually sent.
err = conn.Flush()
if err != nil {
return nil, err
}
return &TLSConn{
tlsConn: tc,
nbConn: conn,
}, nil
}
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
// tls.Conn.
type TLSConn struct {
tlsConn *tls.Conn
nbConn *NetConn
}
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
func (tc *TLSConn) Close() error {
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
// own 5 second deadline then make all set deadlines no-op.
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
return tc.tlsConn.Close()
}
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }

View File

@ -1,11 +0,0 @@
//go:build !unix && !windows
package nbconn
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
return c.fakeNonblockingWrite(b)
}
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
return c.fakeNonblockingRead(b)
}

View File

@ -1,86 +0,0 @@
//go:build unix
package nbconn
import (
"errors"
"io"
"syscall"
)
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
if c.nonblockWriteFunc == nil {
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
return true
}
}
c.nonblockWriteBuf = b
c.nonblockWriteN = 0
c.nonblockWriteErr = nil
err = c.rawConn.Write(c.nonblockWriteFunc)
n = c.nonblockWriteN
c.nonblockWriteBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockWriteErr != nil {
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockWriteErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
return n, nil
}
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
if c.nonblockReadFunc == nil {
c.nonblockReadFunc = func(fd uintptr) (done bool) {
c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf)
return true
}
}
c.nonblockReadBuf = b
c.nonblockReadN = 0
c.nonblockReadErr = nil
err = c.rawConn.Read(c.nonblockReadFunc)
n = c.nonblockReadN
c.nonblockReadBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockReadErr != nil {
if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockReadErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
// syscall read did not return an error and 0 bytes were read means EOF.
if n == 0 {
return 0, io.EOF
}
return n, nil
}
func (c *NetConn) SetBlockingMode(blocking bool) error {
// Do nothing on UNIX systems
return nil
}

View File

@ -1,227 +0,0 @@
//go:build windows
package nbconn
import (
"errors"
"fmt"
"golang.org/x/sys/windows"
"io"
"syscall"
"time"
"unsafe"
)
var dll = syscall.MustLoadDLL("ws2_32.dll")
// int ioctlsocket(
//
// [in] SOCKET s,
// [in] long cmd,
// [in, out] u_long *argp
//
// );
var ioctlsocket = dll.MustFindProc("ioctlsocket")
var deadlineExpErr = errors.New("i/o timeout")
type sockMode int
const (
FIONBIO uint32 = 0x8004667e
sockModeBlocking sockMode = 0
sockModeNonBlocking sockMode = 1
)
func setSockMode(fd uintptr, mode sockMode) error {
res, _, err := ioctlsocket.Call(fd, uintptr(FIONBIO), uintptr(unsafe.Pointer(&mode)))
// Upon successful completion, the ioctlsocket returns zero.
if res != 0 && err != nil {
return err
}
return nil
}
func (c *NetConn) isDeadlineSet(dl time.Time) bool {
return !dl.IsZero() && !dl.Equal(NonBlockingDeadline) && !dl.Equal(disableSetDeadlineDeadline)
}
func (c *NetConn) isWriteDeadlineExpired() bool {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
return c.isDeadlineSet(c.writeDeadline) && !time.Now().Before(c.writeDeadline)
}
func (c *NetConn) isReadDeadlineExpired() bool {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
return c.isDeadlineSet(c.readDeadline) && !time.Now().Before(c.readDeadline)
}
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
if c.nonblockWriteFunc == nil {
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
var written uint32
var buf syscall.WSABuf
buf.Buf = &c.nonblockWriteBuf[0]
buf.Len = uint32(len(c.nonblockWriteBuf))
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
c.nonblockWriteN = int(written)
return true
}
}
c.nonblockWriteBuf = b
c.nonblockWriteN = 0
c.nonblockWriteErr = nil
if c.isWriteDeadlineExpired() {
c.nonblockWriteErr = deadlineExpErr
return 0, c.nonblockWriteErr
}
err = c.rawConn.Write(c.nonblockWriteFunc)
n = c.nonblockWriteN
c.nonblockWriteBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockWriteErr != nil {
if errors.Is(c.nonblockWriteErr, windows.WSAEWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockWriteErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
return n, nil
}
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
if c.nonblockReadFunc == nil {
c.nonblockReadFunc = func(fd uintptr) (done bool) {
var read uint32
var flags uint32
var buf syscall.WSABuf
buf.Buf = &c.nonblockReadBuf[0]
buf.Len = uint32(len(c.nonblockReadBuf))
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
c.nonblockReadN = int(read)
return true
}
}
c.nonblockReadBuf = b
c.nonblockReadN = 0
c.nonblockReadErr = nil
if c.isReadDeadlineExpired() {
c.nonblockReadErr = deadlineExpErr
return 0, c.nonblockReadErr
}
err = c.rawConn.Read(c.nonblockReadFunc)
n = c.nonblockReadN
c.nonblockReadBuf = nil // ensure that no reference to b is kept.
if err == nil && c.nonblockReadErr != nil {
if errors.Is(c.nonblockReadErr, windows.WSAEWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockReadErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
// syscall read did not return an error and 0 bytes were read means EOF.
if n == 0 {
return 0, io.EOF
}
return n, nil
}
func (c *NetConn) SetBlockingMode(blocking bool) error {
// Fake non-blocking I/O is ignored
if c.rawConn == nil {
return nil
}
// Prevent concurrent SetBlockingMode calls
c.nbOperMu.Lock()
defer c.nbOperMu.Unlock()
// Guard against negative value (which should never happen in practice)
if c.nbOperCnt < 0 {
c.nbOperCnt = 0
}
if blocking {
// Socket is already in blocking mode
if c.nbOperCnt == 0 {
return nil
}
c.nbOperCnt--
// Not ready to exit from non-blocking mode, there is pending non-blocking operations
if c.nbOperCnt > 0 {
return nil
}
} else {
c.nbOperCnt++
// Socket is already in non-blocking mode
if c.nbOperCnt > 1 {
return nil
}
}
mode := sockModeNonBlocking
if blocking {
mode = sockModeBlocking
}
var ctrlErr, err error
ctrlErr = c.rawConn.Control(func(fd uintptr) {
err = setSockMode(fd, mode)
})
if ctrlErr != nil || err != nil {
retErr := ctrlErr
if retErr == nil {
retErr = err
}
// Revert counters inc/dec in case of error
if blocking {
c.nbOperCnt++
return fmt.Errorf("cannot set socket to blocking mode: %w", retErr)
} else {
c.nbOperCnt--
return fmt.Errorf("cannot set socket to non-blocking mode: %w", retErr)
}
}
return nil
}

View File

@ -1,599 +0,0 @@
package nbconn_test
import (
"crypto/tls"
"io"
"net"
"runtime"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test keys generated with:
//
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
E7ZYmaKTMOhvkg==
-----END CERTIFICATE-----`)
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
// source code.
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
2qWm8jTPeDC3sq+67s2oojHf+Q==
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
for _, tt := range []struct {
name string
makeConns func(t *testing.T) (local, remote net.Conn)
useTLS bool
fakeNonBlockingIO bool
}{
{
name: "Pipe",
makeConns: makePipeConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TLS over TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: true,
},
{
name: "TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: false,
},
{
name: "TLS over TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
local, remote := tt.makeConns(t)
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
// uses remote it may be garbage collected leading to the connection being closed.
defer local.Close()
defer remote.Close()
var conn nbconn.Conn
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
if tt.useTLS {
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
require.NoError(t, err)
tlsServer := tls.Server(remote, &tls.Config{
Certificates: []tls.Certificate{cert},
})
serverTLSHandshakeChan := make(chan error)
go func() {
err := tlsServer.Handshake()
serverTLSHandshakeChan <- err
}()
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
require.NoError(t, err)
conn = tlsConn
err = <-serverTLSHandshakeChan
require.NoError(t, err)
remote = tlsServer
} else {
conn = netConn
}
f(t, conn, remote)
})
}
}
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
func makePipeConns(t *testing.T) (local, remote net.Conn) {
local, remote = net.Pipe()
t.Cleanup(func() {
local.Close()
remote.Close()
})
return local, remote
}
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
type acceptResultT struct {
conn net.Conn
err error
}
acceptChan := make(chan acceptResultT)
go func() {
conn, err := ln.Accept()
acceptChan <- acceptResultT{conn: conn, err: err}
}()
local, err = net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
acceptResult := <-acceptChan
require.NoError(t, acceptResult.err)
remote = acceptResult.conn
return local, remote
}
func TestWriteIsBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
// net.Pipe is synchronous so the Write would block if not buffered.
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
err := conn.Flush()
errChan <- err
}()
readBuf := make([]byte, len(writeBuf))
_, err = remote.Read(readBuf)
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 2)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
_, err = remote.Write([]byte("okay"))
errChan <- err
}()
readBuf := make([]byte, 4)
_, err = conn.Read(readBuf)
require.NoError(t, err)
require.Equal(t, []byte("okay"), readBuf)
require.NoError(t, <-errChan)
require.NoError(t, <-errChan)
})
}
func TestCloseFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
}()
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing
// large values.
func TestInternalNonBlockingWrite(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
errChan := make(chan error, 1)
go func() {
remoteWriteBuf := make([]byte, deadlockSize)
_, err := remote.Write(remoteWriteBuf)
if err != nil {
errChan <- err
return
}
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(remote, readBuf)
errChan <- err
}()
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(conn, readBuf)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
if runtime.GOOS == "windows" && t.Name() == "TestInternalNonBlockingWrite/TLS_over_TCP_with_Fake_Non-blocking_IO" {
// this test is expected to fail on Windows see https://github.com/golang/go/issues/58764
require.Error(t, <-errChan)
} else {
require.NoError(t, <-errChan)
}
})
}
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
err = conn.SetDeadline(time.Now())
require.NoError(t, err)
err = conn.Flush()
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
})
}
func TestNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
require.EqualValues(t, 0, n)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
err = conn.SetReadDeadline(time.Time{})
require.NoError(t, err)
n, err = conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestBufferNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.BufferReadUntilBlock()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
err := remote.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err != nil {
errChan <- err
return
}
_, err = remote.Write([]byte("okay"))
errChan <- err
}()
readLoop:
for i := 0; i < 1000; i++ {
err := conn.BufferReadUntilBlock()
require.NoError(t, err)
select {
case err := <-errChan:
require.NoError(t, err)
break readLoop
default:
time.Sleep(time.Millisecond)
}
}
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.EqualValues(t, 4, n)
assert.Equal(t, []byte("okay"), buf)
})
}
func TestReadPreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 5)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf)
})
}
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 10)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf[:n])
})
}
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 2)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 2, n)
require.Equal(t, []byte("al"), readBuf)
readBuf = make([]byte, 3)
n, err = conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("pha"), readBuf)
})
}
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
})
}
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
flushCompleteChan := make(chan struct{})
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
<-flushCompleteChan
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
close(flushCompleteChan)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
err = <-errChan
require.NoError(t, err)
})
}

View File

@ -16,7 +16,6 @@ import (
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
@ -65,7 +64,7 @@ type NotificationHandler func(*PgConn, *Notification)
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection
conn net.Conn
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
@ -266,14 +265,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
nbNetConn := nbconn.NewNetConn(netConn, false)
pgConn.conn = nbNetConn
pgConn.contextWatcher = newContextWatcher(nbNetConn)
pgConn.conn = netConn
pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx)
if fallbackConfig.TLSConfig != nil {
nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig)
nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
@ -392,7 +390,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
)
}
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) {
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return nil, err
@ -407,12 +405,7 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err
return nil, errors.New("server refused TLS connection")
}
tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
if err != nil {
return nil, err
}
return tlsConn, nil
return tls.Client(conn, tlsConfig), nil
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
@ -468,10 +461,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.frontend.Receive()
if err != nil {
if errors.Is(err, nbconn.ErrWouldBlock) {
return nil, err
}
// Close on anything other than timeout error - everything else is fatal
var netErr net.Error
isNetErr := errors.As(err, &netErr)
@ -1174,11 +1163,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return CommandTag{}, err
}
err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
// err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
// if err != nil {
// pgConn.asyncClose()
// return CommandTag{}, err
// }
nonblocking := true
defer func() {
if nonblocking {
@ -1217,9 +1206,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
for pgErr == nil {
msg, err := pgConn.receiveMessage()
if err != nil {
if errors.Is(err, nbconn.ErrWouldBlock) {
break
}
// if errors.Is(err, nbconn.ErrWouldBlock) {
// break
// }
pgConn.asyncClose()
return CommandTag{}, normalizeTimeoutError(ctx, err)
}
@ -1638,15 +1627,19 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
// without the client knowing whether the server received it or not.
func (pgConn *PgConn) CheckConn() error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
// if err := pgConn.lock(); err != nil {
// return err
// }
// defer pgConn.unlock()
err := pgConn.conn.BufferReadUntilBlock()
if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
return err
}
rr := pgConn.ExecParams(context.Background(), "select 1", nil, nil, nil, nil)
_, err := rr.Close()
return err
// err := pgConn.conn.BufferReadUntilBlock()
// if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
// return err
// }
return nil
}
@ -1660,7 +1653,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
type HijackedConn struct {
Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection
Conn net.Conn
PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server