pgx/internal/nbconn/nbconn.go

353 lines
7.3 KiB
Go

// Package nbconn implements a non-blocking net.Conn wrapper.
package nbconn
import (
"crypto/tls"
"errors"
"net"
"os"
"sync"
"sync/atomic"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var errClosed = errors.New("closed")
var ErrWouldBlock = errors.New("would block")
const fakeNonblockingWaitDuration = 100 * time.Millisecond
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
// Conn is a non-blocking net.Conn wrapper. It implements net.Conn.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
type Conn struct {
netConn net.Conn
tlsConn *tls.Conn
maybeTLSConn net.Conn
readQueue bufferQueue
writeQueue bufferQueue
readFlushLock sync.Mutex
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
writeDeadlineLock sync.Mutex
writeDeadline time.Time
// Only access with atomics
closed int64 // 0 = not closed, 1 = closed
}
func New(conn net.Conn) *Conn {
return &Conn{
netConn: conn,
maybeTLSConn: conn,
}
}
// StartTLS starts using TLS. It must not be called concurrently with any other method and must only be called once.
func (c *Conn) StartTLS(config *tls.Config) {
c.tlsConn = tls.Client(c.netConn, config)
c.maybeTLSConn = c.tlsConn
}
// Read implements io.Reader.
func (c *Conn) Read(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return 0, err
}
for n < len(b) {
buf := c.readQueue.popFront()
if buf == nil {
break
}
copiedN := copy(b[n:], buf)
if copiedN < len(buf) {
buf = buf[copiedN:]
c.readQueue.pushFront(buf)
} else {
releaseBuf(buf)
}
n += copiedN
}
if n == len(b) {
return n, nil
}
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
var readN int
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.maybeTLSConn.Read(b[n:])
}
n += readN
return n, err
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
// closed. Call Flush to actually write to the underlying connection.
func (c *Conn) Write(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
buf := iobufpool.Get(len(b))
copy(buf, b)
c.writeQueue.pushBack(buf)
return len(b), nil
}
func (c *Conn) Close() (err error) {
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
if !swapped {
return errClosed
}
defer func() {
closeErr := c.maybeTLSConn.Close()
if err == nil {
err = closeErr
}
}()
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return err
}
return nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.maybeTLSConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.maybeTLSConn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *Conn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.maybeTLSConn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
c.writeDeadline = t
return c.netConn.SetWriteDeadline(t)
}
func (c *Conn) Flush() error {
if c.isClosed() {
return errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
return c.flush()
}
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *Conn) flush() error {
var stopChan chan struct{}
var errChan chan error
defer func() {
if stopChan != nil {
select {
case stopChan <- struct{}{}:
case <-errChan:
}
}
}()
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := buf
for len(remainingBuf) > 0 {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
return err
}
// Writing was blocked. Reading might unblock it.
if stopChan == nil {
stopChan, errChan = c.bufferNonblockingRead()
}
select {
case err := <-errChan:
stopChan = nil
return err
default:
}
}
}
releaseBuf(buf)
}
return nil
}
func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
stopChan = make(chan struct{})
errChan = make(chan error, 1)
go func() {
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(buf)
if n > 0 {
buf = buf[:n]
c.readQueue.pushBack(buf)
}
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
errChan <- err
return
}
}
select {
case <-stopChan:
return
default:
}
}
}()
return stopChan, errChan
}
func (c *Conn) isClosed() bool {
closed := atomic.LoadInt64(&c.closed)
return closed == 1
}
func (c *Conn) nonblockingWrite(b []byte) (n int, err error) {
return c.fakeNonblockingWrite(b)
}
func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.netConn.SetWriteDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.netConn.SetWriteDeadline(c.writeDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.netConn.Write(b)
}
func (c *Conn) nonblockingRead(b []byte) (n int, err error) {
return c.fakeNonblockingRead(b)
}
func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
err = c.netConn.SetReadDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.netConn.SetReadDeadline(c.readDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.netConn.Read(b)
}
// syscall.Conn is interface