// Package nbbconn implements a non-blocking, buffered net.Conn wrapper. package nbbconn import ( "errors" "net" "os" "sync" "sync/atomic" "time" "github.com/jackc/pgx/v5/internal/iobufpool" ) var errClosed = errors.New("closed") var ErrWouldBlock = errors.New("would block") const fakeNonblockingWaitDuration = 100 * time.Millisecond var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) // Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn. // // It is designed to solve three problems. // // The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all // buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. // // The second is the inability to use a write deadline with a TLS.Conn without killing the connection. // // The third is to efficiently check if a connection has been closed via a non-blocking read. type Conn struct { netConn net.Conn readQueue bufferQueue writeQueue bufferQueue readFlushLock sync.Mutex readDeadlineLock sync.Mutex readDeadline time.Time readNonblocking bool writeDeadlineLock sync.Mutex writeDeadline time.Time // Only access with atomics closed int64 // 0 = not closed, 1 = closed } func New(conn net.Conn) *Conn { return &Conn{ netConn: conn, } } func (c *Conn) Read(b []byte) (n int, err error) { if c.isClosed() { return 0, errClosed } c.readFlushLock.Lock() defer c.readFlushLock.Unlock() err = c.flush() if err != nil { return 0, err } buf := c.readQueue.popFront() if buf != nil { n = copy(b, buf) if n < len(buf) { buf = buf[n:] c.readQueue.pushFront(buf) } else { releaseBuf(buf) } return n, nil // TODO - must return error if n != len(b) } var readNonblocking bool c.readDeadlineLock.Lock() readNonblocking = c.readNonblocking c.readDeadlineLock.Unlock() if readNonblocking { return c.nonblockingRead(b) } else { return c.netConn.Read(b) } } // Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is // closed. Call Flush to actually write to the underlying connection. func (c *Conn) Write(b []byte) (n int, err error) { if c.isClosed() { return 0, errClosed } buf := iobufpool.Get(len(b)) copy(buf, b) c.writeQueue.pushBack(buf) return len(b), nil } func (c *Conn) Close() (err error) { swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) if !swapped { return errClosed } defer func() { closeErr := c.netConn.Close() if err == nil { err = closeErr } }() c.readFlushLock.Lock() defer c.readFlushLock.Unlock() err = c.flush() if err != nil { return err } return nil } func (c *Conn) LocalAddr() net.Addr { return c.netConn.LocalAddr() } func (c *Conn) RemoteAddr() net.Addr { return c.netConn.RemoteAddr() } // SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). func (c *Conn) SetDeadline(t time.Time) error { err := c.SetReadDeadline(t) if err != nil { return err } return c.SetWriteDeadline(t) } // SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. func (c *Conn) SetReadDeadline(t time.Time) error { if c.isClosed() { return errClosed } c.readDeadlineLock.Lock() defer c.readDeadlineLock.Unlock() if t == NonBlockingDeadline { c.readNonblocking = true t = time.Time{} } else { c.readNonblocking = false } c.readDeadline = t return c.netConn.SetReadDeadline(t) } func (c *Conn) SetWriteDeadline(t time.Time) error { if c.isClosed() { return errClosed } c.writeDeadlineLock.Lock() defer c.writeDeadlineLock.Unlock() c.writeDeadline = t return c.netConn.SetWriteDeadline(t) } func (c *Conn) Flush() error { if c.isClosed() { return errClosed } c.readFlushLock.Lock() defer c.readFlushLock.Unlock() return c.flush() } // flush does the actual work of flushing the writeQueue. readFlushLock must already be held. func (c *Conn) flush() error { var stopChan chan struct{} var errChan chan error defer func() { if stopChan != nil { select { case stopChan <- struct{}{}: case <-errChan: } } }() for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { remainingBuf := buf for len(remainingBuf) > 0 { n, err := c.nonblockingWrite(remainingBuf) remainingBuf = remainingBuf[n:] if err != nil { if !errors.Is(err, ErrWouldBlock) { buf = buf[:len(remainingBuf)] copy(buf, remainingBuf) c.writeQueue.pushFront(buf) return err } // Writing was blocked. Reading might unblock it. if stopChan == nil { stopChan, errChan = c.bufferNonblockingRead() } select { case err := <-errChan: stopChan = nil return err default: } } } releaseBuf(buf) } return nil } func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { stopChan = make(chan struct{}) errChan = make(chan error, 1) go func() { for { buf := iobufpool.Get(8 * 1024) n, err := c.nonblockingRead(buf) if n > 0 { buf = buf[:n] c.readQueue.pushBack(buf) } if err != nil { if !errors.Is(err, ErrWouldBlock) { errChan <- err return } } select { case <-stopChan: return default: } } }() return stopChan, errChan } func (c *Conn) isClosed() bool { closed := atomic.LoadInt64(&c.closed) return closed == 1 } func (c *Conn) nonblockingWrite(b []byte) (n int, err error) { return c.fakeNonblockingWrite(b) } func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) { c.writeDeadlineLock.Lock() defer c.writeDeadlineLock.Unlock() deadline := time.Now().Add(fakeNonblockingWaitDuration) if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { err = c.netConn.SetWriteDeadline(deadline) if err != nil { return 0, err } defer func() { // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. c.netConn.SetWriteDeadline(c.writeDeadline) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { err = ErrWouldBlock } } }() } return c.netConn.Write(b) } func (c *Conn) nonblockingRead(b []byte) (n int, err error) { return c.fakeNonblockingRead(b) } func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) { c.readDeadlineLock.Lock() defer c.readDeadlineLock.Unlock() deadline := time.Now().Add(fakeNonblockingWaitDuration) if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { err = c.netConn.SetReadDeadline(deadline) if err != nil { return 0, err } defer func() { // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. c.netConn.SetReadDeadline(c.readDeadline) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { err = ErrWouldBlock } } }() } return c.netConn.Read(b) } // syscall.Conn is interface