non-blocking
Jack Christensen 2022-06-02 19:32:55 -05:00
parent 2e7b46d5d7
commit ca22396789
5 changed files with 371 additions and 117 deletions

View File

@ -0,0 +1,76 @@
package nbbconn
import (
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
const minBufferQueueLen = 8
type bufferQueue struct {
lock sync.Mutex
queue [][]byte
r, w int
}
func (bq *bufferQueue) pushBack(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}
func (bq *bufferQueue) pushFront(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}
func (bq *bufferQueue) popFront() []byte {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.r == bq.w {
return nil
}
buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++
if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([][]byte, minBufferQueueLen)
}
}
return buf
}
func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}
newQueue := make([][]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}
func releaseBuf(buf []byte) {
iobufpool.Put(buf[:cap(buf)])
}

View File

@ -13,10 +13,12 @@ import (
)
var errClosed = errors.New("closed")
var errWouldBlock = errors.New("would block")
var ErrWouldBlock = errors.New("would block")
const fakeNonblockingWaitDuration = 100 * time.Millisecond
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
// Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn.
//
// It is designed to solve three problems.
@ -37,6 +39,7 @@ type Conn struct {
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
writeDeadlineLock sync.Mutex
writeDeadline time.Time
@ -74,9 +77,19 @@ func (c *Conn) Read(b []byte) (n int, err error) {
releaseBuf(buf)
}
return n, nil
// TODO - must return error if n != len(b)
}
return c.netConn.Read(b)
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
if readNonblocking {
return c.nonblockingRead(b)
} else {
return c.netConn.Read(b)
}
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
@ -123,22 +136,16 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.netConn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *Conn) SetDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
err := c.SetReadDeadline(t)
if err != nil {
return err
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
c.readDeadline = t
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
c.writeDeadline = t
return c.netConn.SetDeadline(t)
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
@ -146,6 +153,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.netConn.SetReadDeadline(t)
@ -193,7 +208,7 @@ func (c *Conn) flush() error {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, errWouldBlock) {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
@ -234,7 +249,7 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err
}
if err != nil {
if !errors.Is(err, errWouldBlock) {
if !errors.Is(err, ErrWouldBlock) {
errChan <- err
return
}
@ -276,7 +291,7 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) {
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = errWouldBlock
err = ErrWouldBlock
}
}
}()
@ -305,7 +320,7 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) {
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = errWouldBlock
err = ErrWouldBlock
}
}
}()

View File

@ -0,0 +1,129 @@
package nbbconn_test
import (
"net"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/nbbconn"
"github.com/stretchr/testify/require"
)
func TestWriteIsBuffered(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()
conn := nbbconn.New(local)
// net.Pipe is synchronous so the Write would block if not buffered.
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
err := conn.Flush()
errChan <- err
}()
readBuf := make([]byte, len(writeBuf))
_, err = remote.Read(readBuf)
require.NoError(t, err)
require.NoError(t, <-errChan)
}
func TestReadFlushesWriteBuffer(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()
conn := nbbconn.New(local)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 2)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
_, err = remote.Write([]byte("okay"))
errChan <- err
}()
readBuf := make([]byte, 4)
_, err = conn.Read(readBuf)
require.NoError(t, err)
require.Equal(t, []byte("okay"), readBuf)
require.NoError(t, <-errChan)
require.NoError(t, <-errChan)
}
func TestCloseFlushesWriteBuffer(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()
conn := nbbconn.New(local)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
}()
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
}
func TestNonBlockingRead(t *testing.T) {
local, remote := net.Pipe()
defer func() {
local.Close()
remote.Close()
}()
conn := nbbconn.New(local)
err := conn.SetReadDeadline(nbbconn.NonBlockingDeadline)
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.ErrorIs(t, err, nbbconn.ErrWouldBlock)
require.EqualValues(t, 0, n)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
err = conn.SetReadDeadline(time.Time{})
require.NoError(t, err)
n, err = conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
}

View File

@ -0,0 +1,75 @@
package nbbconn
import (
"sync"
)
const minQueueLen = 8
type queue[T any] struct {
lock sync.Mutex
queue []T
r, w int
}
func (q *queue[T]) pushBack(item T) {
q.lock.Lock()
defer q.lock.Unlock()
if q.w >= len(q.queue) {
q.growQueue()
}
q.queue[q.w] = item
q.w++
}
func (q *queue[T]) pushFront(item T) {
q.lock.Lock()
defer q.lock.Unlock()
if q.w >= len(q.queue) {
q.growQueue()
}
copy(q.queue[q.r+1:q.w+1], q.queue[q.r:q.w])
q.queue[q.r] = item
q.w++
}
func (q *queue[T]) popFront() (T, bool) {
q.lock.Lock()
defer q.lock.Unlock()
if q.r == q.w {
var zero T
return zero, false
}
item := q.queue[q.r]
// Clear reference so it can be garbage collected.
var zero T
q.queue[q.r] = zero
q.r++
if q.r == q.w {
q.r = 0
q.w = 0
if len(q.queue) > minQueueLen {
q.queue = make([]T, minQueueLen)
}
}
return item, true
}
func (q *queue[T]) growQueue() {
desiredLen := (len(q.queue) + 1) * 3 / 2
if desiredLen < minQueueLen {
desiredLen = minQueueLen
}
newQueue := make([]T, desiredLen)
copy(newQueue, q.queue)
q.queue = newQueue
}

View File

@ -13,9 +13,9 @@ import (
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/nbbconn"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
@ -76,11 +76,6 @@ type PgConn struct {
status byte // One of connStatus* constants
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
@ -254,6 +249,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
pgConn.conn = nbbconn.New(pgConn.conn)
pgConn.contextWatcher.Unwatch() // context watcher should watch nbbconn
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
defer pgConn.contextWatcher.Unwatch()
@ -388,24 +385,6 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
return ch
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@ -445,25 +424,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil
}
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.frontend.Receive()
}
msg, err := pgConn.frontend.Receive()
if err != nil {
if errors.Is(err, nbbconn.ErrWouldBlock) {
return nil, err
}
// Close on anything other than timeout error - everything else is fatal
var netErr net.Error
isNetErr := errors.As(err, &netErr)
@ -482,13 +449,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := pgConn.peekMessage()
if err != nil {
// Close on anything other than timeout error - everything else is fatal
var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
}
return nil, err
}
pgConn.peekedMsg = nil
@ -1176,62 +1136,57 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// Send copy to command
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
// Send copy data
abortCopyChan := make(chan struct{})
copyErrChan := make(chan error, 1)
signalMessageChan := pgConn.signalMessage()
senderDoneChan := make(chan struct{})
go func() {
defer close(senderDoneChan)
buf := make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
for {
n, readErr := r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
if writeErr != nil {
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
pgConn.conn.Close()
copyErrChan <- writeErr
return
}
}
if readErr != nil {
copyErrChan <- readErr
return
}
select {
case <-abortCopyChan:
return
default:
}
err = pgConn.conn.SetReadDeadline(nbbconn.NonBlockingDeadline)
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking := true
defer func() {
if nonblocking {
pgConn.conn.SetReadDeadline(time.Time{})
}
}()
var pgErr error
var copyErr error
for copyErr == nil && pgErr == nil {
select {
case copyErr = <-copyErrChan:
case <-signalMessageChan:
buf := iobufpool.Get(65536)
buf[0] = 'd'
var readErr, pgErr error
for {
// Read chunk from r.
var n int
n, readErr = r.Read(buf[5:cap(buf)])
// Send chunk to PostgreSQL.
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
if writeErr != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
}
// Abort loop if there was a read error.
if readErr != nil {
break
}
// Read messages until error or none available.
for {
msg, err := pgConn.receiveMessage()
if err != nil {
if errors.Is(err, nbbconn.ErrWouldBlock) {
break
}
pgConn.asyncClose()
return CommandTag{}, preferContextOverNetTimeoutError(ctx, err)
}
@ -1239,18 +1194,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
default:
signalMessageChan = pgConn.signalMessage()
break
}
}
}
close(abortCopyChan)
<-senderDoneChan
if copyErr == io.EOF || pgErr != nil {
err = pgConn.conn.SetReadDeadline(time.Time{})
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking = false
if readErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
}
err = pgConn.frontend.Flush()
if err != nil {