pgx/internal/nbbconn/nbbconn.go

333 lines
6.9 KiB
Go

// Package nbbconn implements a non-blocking, buffered net.Conn wrapper.
package nbbconn
import (
"errors"
"net"
"os"
"sync"
"sync/atomic"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var (
	// errClosed is returned by operations attempted on a Conn that has already been closed.
	errClosed = errors.New("closed")

	// ErrWouldBlock is returned when a non-blocking read or write could not make progress without waiting.
	ErrWouldBlock = errors.New("would block")
)

// fakeNonblockingWaitDuration is how long an emulated non-blocking read or write waits, via a temporary deadline on
// the underlying connection, before reporting ErrWouldBlock.
const fakeNonblockingWaitDuration = 100 * time.Millisecond

// NonBlockingDeadline is a sentinel value. Passing it to SetReadDeadline does not set a real deadline. Instead it
// clears the read deadline and makes subsequent reads from the underlying connection non-blocking.
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
// Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a tls.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
type Conn struct {
	// closed is accessed with 64-bit atomic operations, so it must be the first field. On 386, ARM, and 32-bit MIPS
	// only the first word of an allocated struct is guaranteed to be 64-bit aligned (see the BUGS section of the
	// sync/atomic package documentation). Only access with atomics.
	closed int64 // 0 = not closed, 1 = closed

	netConn net.Conn

	// readQueue holds data read from netConn while a flush was blocked writing. Read returns it before touching netConn.
	readQueue bufferQueue
	// writeQueue holds data accepted by Write that has not yet been written to netConn.
	writeQueue bufferQueue

	// readFlushLock serializes Read, Flush, and Close, all of which flush writeQueue.
	readFlushLock sync.Mutex

	// readDeadlineLock guards readDeadline and readNonblocking.
	readDeadlineLock sync.Mutex
	readDeadline     time.Time
	// readNonblocking is true when the read deadline was last set to NonBlockingDeadline.
	readNonblocking bool

	// writeDeadlineLock guards writeDeadline.
	writeDeadlineLock sync.Mutex
	writeDeadline     time.Time
}
// New wraps conn in a non-blocking, buffered Conn. Writes to the returned Conn are queued until flushed.
func New(conn net.Conn) *Conn {
	wrapped := &Conn{}
	wrapped.netConn = conn
	return wrapped
}
// Read implements io.Reader.
//
// Any queued writes are flushed first. Data that was buffered while a flush was blocked writing is then returned,
// filling as much of b as possible from the read queue. Only when nothing is buffered does Read go to the underlying
// connection. If the read deadline was set to NonBlockingDeadline, that read is non-blocking and may return
// ErrWouldBlock.
//
// As permitted by io.Reader, Read may return fewer than len(b) bytes.
func (c *Conn) Read(b []byte) (n int, err error) {
	if c.isClosed() {
		return 0, errClosed
	}

	c.readFlushLock.Lock()
	defer c.readFlushLock.Unlock()

	err = c.flush()
	if err != nil {
		return 0, err
	}

	// A zero-length read must not fall through to the underlying connection: buffered data may be pending, and some
	// net.Conn implementations block on zero-length reads.
	if len(b) == 0 {
		return 0, nil
	}

	// Drain as much previously buffered data as fits in b.
	for n < len(b) {
		buf := c.readQueue.popFront()
		if buf == nil {
			break
		}
		copied := copy(b[n:], buf)
		if copied < len(buf) {
			// b is full; keep the unread remainder for the next Read.
			c.readQueue.pushFront(buf[copied:])
		} else {
			releaseBuf(buf)
		}
		n += copied
	}

	// Return buffered bytes without touching the network. A blocking read here could otherwise wait for data that
	// has already been consumed into the read queue.
	if n > 0 {
		return n, nil
	}

	c.readDeadlineLock.Lock()
	readNonblocking := c.readNonblocking
	c.readDeadlineLock.Unlock()

	if readNonblocking {
		return c.nonblockingRead(b)
	}
	return c.netConn.Read(b)
}
// Write implements io.Writer without ever blocking. b is copied into a pooled buffer and appended to the write
// queue, so the caller may reuse b as soon as Write returns. The only error Write returns is for a closed Conn.
// Queued data reaches the underlying connection only when Flush (or Read/Close, which also flush) is called.
func (c *Conn) Write(b []byte) (n int, err error) {
	if c.isClosed() {
		return 0, errClosed
	}

	pending := iobufpool.Get(len(b))
	copy(pending, b)
	c.writeQueue.pushBack(pending)

	return len(b), nil
}
// Close flushes any queued writes and then closes the underlying connection. The underlying connection is closed
// even if the flush fails, and in that case the flush error is the one returned. Calling Close on an already closed
// Conn returns errClosed.
func (c *Conn) Close() (err error) {
	if !atomic.CompareAndSwapInt64(&c.closed, 0, 1) {
		return errClosed
	}

	// Deferred first so it runs last, after readFlushLock has been released.
	defer func() {
		if closeErr := c.netConn.Close(); err == nil {
			err = closeErr
		}
	}()

	c.readFlushLock.Lock()
	defer c.readFlushLock.Unlock()

	return c.flush()
}
// LocalAddr returns the local network address reported by the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	addr := c.netConn.LocalAddr()
	return addr
}
// RemoteAddr returns the remote network address reported by the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	addr := c.netConn.RemoteAddr()
	return addr
}
// SetDeadline sets both deadlines to t. It is equivalent to calling SetReadDeadline(t) followed by
// SetWriteDeadline(t). If setting the read deadline fails, the write deadline is left unchanged.
//
// NOTE: passing NonBlockingDeadline here also sets the write deadline to that (past) time.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline to t. Passing NonBlockingDeadline instead clears the deadline and makes
// future reads from the underlying connection non-blocking. Any other value turns non-blocking reads off.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if c.isClosed() {
		return errClosed
	}

	c.readDeadlineLock.Lock()
	defer c.readDeadlineLock.Unlock()

	// Compared with == (not Time.Equal) so only the exact sentinel value enables non-blocking mode.
	nonblocking := t == NonBlockingDeadline
	if nonblocking {
		t = time.Time{}
	}
	c.readNonblocking = nonblocking
	c.readDeadline = t

	return c.netConn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection. It also records t so that emulated
// non-blocking writes can restore it after installing their own short deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	if c.isClosed() {
		return errClosed
	}

	c.writeDeadlineLock.Lock()
	defer c.writeDeadlineLock.Unlock()

	c.writeDeadline = t
	return c.netConn.SetWriteDeadline(t)
}
// Flush writes all queued data to the underlying connection. While a write is blocked, data arriving from the peer
// is buffered so that later Reads can return it.
func (c *Conn) Flush() error {
	if c.isClosed() {
		return errClosed
	}

	c.readFlushLock.Lock()
	defer c.readFlushLock.Unlock()

	return c.flush()
}
// flush writes every buffer in writeQueue to the underlying connection. readFlushLock must already be held.
//
// Writes are attempted without blocking. If a write would block, a background goroutine is started that keeps
// reading from the connection into readQueue. This lets a peer that is itself blocked writing to us make progress and,
// in turn, drain what we are sending. That goroutine is stopped (or has already exited) before flush returns.
//
// If an error ends the flush early, the unwritten remainder of the current buffer is put back at the front of
// writeQueue, so queued data is never silently dropped.
func (c *Conn) flush() error {
	var stopChan chan struct{}
	var errChan chan error

	defer func() {
		// Stop the background reader if one is still running. It either accepts the stop signal or has already exited
		// after reporting an error on the buffered errChan.
		if stopChan != nil {
			select {
			case stopChan <- struct{}{}:
			case <-errChan:
			}
		}
	}()

	for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
		remainingBuf := buf
		for len(remainingBuf) > 0 {
			n, err := c.nonblockingWrite(remainingBuf)
			remainingBuf = remainingBuf[n:]
			if err != nil {
				if !errors.Is(err, ErrWouldBlock) {
					c.requeueUnwritten(buf, remainingBuf)
					return err
				}

				// Writing was blocked. Reading might unblock it.
				if stopChan == nil {
					stopChan, errChan = c.bufferNonblockingRead()
				}

				select {
				case err := <-errChan:
					// The background reader has already exited, so there is nothing left to stop.
					stopChan = nil
					c.requeueUnwritten(buf, remainingBuf)
					return err
				default:
				}
			}
		}
		releaseBuf(buf)
	}

	return nil
}

// requeueUnwritten moves the unwritten tail of buf to its start and pushes it back onto the front of writeQueue, so a
// later flush resumes exactly where this one stopped while reusing buf's backing array.
func (c *Conn) requeueUnwritten(buf, remaining []byte) {
	buf = buf[:len(remaining)]
	copy(buf, remaining) // copy handles the overlapping ranges correctly
	c.writeQueue.pushFront(buf)
}
// bufferNonblockingRead starts a goroutine that repeatedly performs non-blocking reads and appends any data received
// to readQueue. flush uses it while a write is blocked, so the peer is not left blocked writing to us.
//
// The goroutine exits when a value is sent on stopChan, or when a read fails with an error other than ErrWouldBlock.
// In the latter case the error is sent on errChan first. errChan has capacity 1, so that send never blocks.
func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
	stopChan = make(chan struct{})
	errChan = make(chan error, 1)

	go func() {
		for {
			buf := iobufpool.Get(8 * 1024)
			n, err := c.nonblockingRead(buf)
			if n > 0 {
				c.readQueue.pushBack(buf[:n])
			} else {
				// Nothing was read (typically ErrWouldBlock). Return the buffer to the pool rather than discarding it
				// on every poll.
				releaseBuf(buf)
			}

			if err != nil && !errors.Is(err, ErrWouldBlock) {
				errChan <- err
				return
			}

			select {
			case <-stopChan:
				return
			default:
			}
		}
	}()

	return stopChan, errChan
}
// isClosed reports whether Close has been called on c.
func (c *Conn) isClosed() bool {
	return atomic.LoadInt64(&c.closed) == 1
}
// nonblockingWrite writes b without waiting indefinitely and may return ErrWouldBlock. It is currently always
// emulated with a short write deadline (fakeNonblockingWrite).
func (c *Conn) nonblockingWrite(b []byte) (n int, err error) {
	n, err = c.fakeNonblockingWrite(b)
	return n, err
}
// fakeNonblockingWrite emulates a non-blocking write. It temporarily installs a write deadline
// fakeNonblockingWaitDuration from now; a write that times out on that temporary deadline is reported as ErrWouldBlock,
// and the caller's write deadline is restored afterward.
//
// If the caller's own write deadline expires no later than the temporary one, no temporary deadline is installed. The
// write then runs as-is and any timeout is reported unchanged.
func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
	c.writeDeadlineLock.Lock()
	defer c.writeDeadlineLock.Unlock()

	fakeDeadline := time.Now().Add(fakeNonblockingWaitDuration)
	callerDeadlineFirst := !c.writeDeadline.IsZero() && !fakeDeadline.Before(c.writeDeadline)
	if callerDeadlineFirst {
		return c.netConn.Write(b)
	}

	if err = c.netConn.SetWriteDeadline(fakeDeadline); err != nil {
		return 0, err
	}
	defer func() {
		// Restore the caller's deadline. A failure here is ignored; there is nothing reasonable to do about it.
		_ = c.netConn.SetWriteDeadline(c.writeDeadline)
		if errors.Is(err, os.ErrDeadlineExceeded) {
			err = ErrWouldBlock
		}
	}()

	return c.netConn.Write(b)
}
// nonblockingRead reads into b without waiting indefinitely and may return ErrWouldBlock. It is currently always
// emulated with a short read deadline (fakeNonblockingRead).
func (c *Conn) nonblockingRead(b []byte) (n int, err error) {
	n, err = c.fakeNonblockingRead(b)
	return n, err
}
// fakeNonblockingRead emulates a non-blocking read. It temporarily installs a read deadline
// fakeNonblockingWaitDuration from now; a read that times out on that temporary deadline is reported as ErrWouldBlock,
// and the caller's read deadline is restored afterward.
//
// If the caller's own read deadline expires no later than the temporary one, no temporary deadline is installed. The
// read then runs as-is and any timeout is reported unchanged.
func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
	c.readDeadlineLock.Lock()
	defer c.readDeadlineLock.Unlock()

	fakeDeadline := time.Now().Add(fakeNonblockingWaitDuration)
	callerDeadlineFirst := !c.readDeadline.IsZero() && !fakeDeadline.Before(c.readDeadline)
	if callerDeadlineFirst {
		return c.netConn.Read(b)
	}

	if err = c.netConn.SetReadDeadline(fakeDeadline); err != nil {
		return 0, err
	}
	defer func() {
		// Restore the caller's deadline. A failure here is ignored; there is nothing reasonable to do about it.
		_ = c.netConn.SetReadDeadline(c.readDeadline)
		if errors.Is(err, os.ErrDeadlineExceeded) {
			err = ErrWouldBlock
		}
	}()

	return c.netConn.Read(b)
}
// syscall.Conn is interface