From ca22396789d120ff556f9704f4470268fbc8c0d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 2 Jun 2022 19:32:55 -0500 Subject: [PATCH] wip --- internal/nbbconn/bufferqueue.go.deleted | 76 ++++++++++++ internal/nbbconn/nbbconn.go | 51 +++++--- internal/nbbconn/nbbconn_test.go | 129 +++++++++++++++++++ internal/nbbconn/queue.go.deleted | 75 +++++++++++ pgconn/pgconn.go | 157 +++++++++--------------- 5 files changed, 371 insertions(+), 117 deletions(-) create mode 100644 internal/nbbconn/bufferqueue.go.deleted create mode 100644 internal/nbbconn/nbbconn_test.go create mode 100644 internal/nbbconn/queue.go.deleted diff --git a/internal/nbbconn/bufferqueue.go.deleted b/internal/nbbconn/bufferqueue.go.deleted new file mode 100644 index 00000000..65a895bc --- /dev/null +++ b/internal/nbbconn/bufferqueue.go.deleted @@ -0,0 +1,76 @@ +package nbbconn + +import ( + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const minBufferQueueLen = 8 + +type bufferQueue struct { + lock sync.Mutex + queue [][]byte + r, w int +} + +func (bq *bufferQueue) pushBack(buf []byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + bq.queue[bq.w] = buf + bq.w++ +} + +func (bq *bufferQueue) pushFront(buf []byte) { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.w >= len(bq.queue) { + bq.growQueue() + } + copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) + bq.queue[bq.r] = buf + bq.w++ +} + +func (bq *bufferQueue) popFront() []byte { + bq.lock.Lock() + defer bq.lock.Unlock() + + if bq.r == bq.w { + return nil + } + + buf := bq.queue[bq.r] + bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. + bq.r++ + + if bq.r == bq.w { + bq.r = 0 + bq.w = 0 + if len(bq.queue) > minBufferQueueLen { + bq.queue = make([][]byte, minBufferQueueLen) + } + } + + return buf +} + +func (bq *bufferQueue) growQueue() { + desiredLen := (len(bq.queue) + 1) * 3 / 2 + if desiredLen < minBufferQueueLen { + desiredLen = minBufferQueueLen + } + + newQueue := make([][]byte, desiredLen) + copy(newQueue, bq.queue) + bq.queue = newQueue +} + +func releaseBuf(buf []byte) { + iobufpool.Put(buf[:cap(buf)]) +} diff --git a/internal/nbbconn/nbbconn.go b/internal/nbbconn/nbbconn.go index 2ee5a8bf..d659c6cc 100644 --- a/internal/nbbconn/nbbconn.go +++ b/internal/nbbconn/nbbconn.go @@ -13,10 +13,12 @@ import ( ) var errClosed = errors.New("closed") -var errWouldBlock = errors.New("would block") +var ErrWouldBlock = errors.New("would block") const fakeNonblockingWaitDuration = 100 * time.Millisecond +var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) + // Conn is a non-blocking, buffered net.Conn wrapper. It implements net.Conn. // // It is designed to solve three problems. @@ -37,6 +39,7 @@ type Conn struct { readDeadlineLock sync.Mutex readDeadline time.Time + readNonblocking bool writeDeadlineLock sync.Mutex writeDeadline time.Time @@ -74,9 +77,19 @@ func (c *Conn) Read(b []byte) (n int, err error) { releaseBuf(buf) } return n, nil + // TODO - must return error if n != len(b) } - return c.netConn.Read(b) + var readNonblocking bool + c.readDeadlineLock.Lock() + readNonblocking = c.readNonblocking + c.readDeadlineLock.Unlock() + + if readNonblocking { + return c.nonblockingRead(b) + } else { + return c.netConn.Read(b) + } } // Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is @@ -123,22 +136,16 @@ func (c *Conn) RemoteAddr() net.Addr { return c.netConn.RemoteAddr() } +// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). func (c *Conn) SetDeadline(t time.Time) error { - if c.isClosed() { - return errClosed + err := c.SetReadDeadline(t) + if err != nil { + return err } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - c.readDeadline = t - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - c.writeDeadline = t - - return c.netConn.SetDeadline(t) + return c.SetWriteDeadline(t) } +// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. func (c *Conn) SetReadDeadline(t time.Time) error { if c.isClosed() { return errClosed @@ -146,6 +153,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadlineLock.Lock() defer c.readDeadlineLock.Unlock() + + if t == NonBlockingDeadline { + c.readNonblocking = true + t = time.Time{} + } else { + c.readNonblocking = false + } + c.readDeadline = t return c.netConn.SetReadDeadline(t) @@ -193,7 +208,7 @@ func (c *Conn) flush() error { n, err := c.nonblockingWrite(remainingBuf) remainingBuf = remainingBuf[n:] if err != nil { - if !errors.Is(err, errWouldBlock) { + if !errors.Is(err, ErrWouldBlock) { buf = buf[:len(remainingBuf)] copy(buf, remainingBuf) c.writeQueue.pushFront(buf) @@ -234,7 +249,7 @@ func (c *Conn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan err } if err != nil { - if !errors.Is(err, errWouldBlock) { + if !errors.Is(err, ErrWouldBlock) { errChan <- err return } @@ -276,7 +291,7 @@ func (c *Conn) fakeNonblockingWrite(b []byte) (n int, err error) { if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { - err = errWouldBlock + err = ErrWouldBlock } } }() @@ -305,7 +320,7 @@ func (c *Conn) fakeNonblockingRead(b []byte) (n int, err error) { if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { - err = errWouldBlock + err = ErrWouldBlock } } }() diff --git a/internal/nbbconn/nbbconn_test.go b/internal/nbbconn/nbbconn_test.go new file mode 100644 index 00000000..2898cd25 --- /dev/null +++ b/internal/nbbconn/nbbconn_test.go @@ -0,0 +1,129 @@ +package nbbconn_test + +import ( + "net" + "testing" + "time" + + "github.com/jackc/pgx/v5/internal/nbbconn" + "github.com/stretchr/testify/require" +) + +func TestWriteIsBuffered(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + conn := nbbconn.New(local) + + // net.Pipe is synchronous so the Write would block if not buffered. + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + err := conn.Flush() + errChan <- err + }() + + readBuf := make([]byte, len(writeBuf)) + _, err = remote.Read(readBuf) + require.NoError(t, err) + + require.NoError(t, <-errChan) +} + +func TestReadFlushesWriteBuffer(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + conn := nbbconn.New(local) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 2) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + + _, err = remote.Write([]byte("okay")) + errChan <- err + }() + + readBuf := make([]byte, 4) + _, err = conn.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte("okay"), readBuf) + + require.NoError(t, <-errChan) + require.NoError(t, <-errChan) +} + +func TestCloseFlushesWriteBuffer(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + conn := nbbconn.New(local) + + writeBuf := []byte("test") + n, err := conn.Write(writeBuf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + + errChan := make(chan error, 1) + go func() { + readBuf := make([]byte, len(writeBuf)) + _, err := remote.Read(readBuf) + errChan <- err + }() + + err = conn.Close() + require.NoError(t, err) + + require.NoError(t, <-errChan) +} + +func TestNonBlockingRead(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + conn := nbbconn.New(local) + + err := conn.SetReadDeadline(nbbconn.NonBlockingDeadline) + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.ErrorIs(t, err, nbbconn.ErrWouldBlock) + require.EqualValues(t, 0, n) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + err = conn.SetReadDeadline(time.Time{}) + require.NoError(t, err) + + n, err = conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) +} diff --git a/internal/nbbconn/queue.go.deleted b/internal/nbbconn/queue.go.deleted new file mode 100644 index 00000000..03f2b4a1 --- /dev/null +++ b/internal/nbbconn/queue.go.deleted @@ -0,0 +1,75 @@ +package nbbconn + +import ( + "sync" +) + +const minQueueLen = 8 + +type queue[T any] struct { + lock sync.Mutex + queue []T + r, w int +} + +func (q *queue[T]) pushBack(item T) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.w >= len(q.queue) { + q.growQueue() + } + q.queue[q.w] = item + q.w++ +} + +func (q *queue[T]) pushFront(item T) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.w >= len(q.queue) { + q.growQueue() + } + copy(q.queue[q.r+1:q.w+1], q.queue[q.r:q.w]) + q.queue[q.r] = item + q.w++ +} + +func (q *queue[T]) popFront() (T, bool) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.r == q.w { + var zero T + return zero, false + } + + item := q.queue[q.r] + + // Clear reference so it can be garbage collected. + var zero T + q.queue[q.r] = zero + + q.r++ + + if q.r == q.w { + q.r = 0 + q.w = 0 + if len(q.queue) > minQueueLen { + q.queue = make([]T, minQueueLen) + } + } + + return item, true +} + +func (q *queue[T]) growQueue() { + desiredLen := (len(q.queue) + 1) * 3 / 2 + if desiredLen < minQueueLen { + desiredLen = minQueueLen + } + + newQueue := make([]T, desiredLen) + copy(newQueue, q.queue) + q.queue = newQueue +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index f23b5009..57b298aa 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -13,9 +13,9 @@ import ( "net" "strconv" "strings" - "sync" "time" + "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/nbbconn" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" @@ -76,11 +76,6 @@ type PgConn struct { status byte // One of connStatus* constants - bufferingReceive bool - bufferingReceiveMux sync.Mutex - bufferingReceiveMsg pgproto3.BackendMessage - bufferingReceiveErr error - peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -254,6 +249,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.conn = nbbconn.New(pgConn.conn) + pgConn.contextWatcher.Unwatch() // context watcher should watch nbbconn + pgConn.contextWatcher = newContextWatcher(pgConn.conn) defer pgConn.contextWatcher.Unwatch() @@ -388,24 +385,6 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (pgConn *PgConn) signalMessage() chan struct{} { - if pgConn.bufferingReceive { - panic("BUG: signalMessage when already in progress") - } - - pgConn.bufferingReceive = true - pgConn.bufferingReceiveMux.Lock() - - ch := make(chan struct{}) - go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() - pgConn.bufferingReceiveMux.Unlock() - close(ch) - }() - - return ch -} - // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -445,25 +424,13 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - var msg pgproto3.BackendMessage - var err error - if pgConn.bufferingReceive { - pgConn.bufferingReceiveMux.Lock() - msg = pgConn.bufferingReceiveMsg - err = pgConn.bufferingReceiveErr - pgConn.bufferingReceiveMux.Unlock() - pgConn.bufferingReceive = false - - // If a timeout error happened in the background try the read again. - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - msg, err = pgConn.frontend.Receive() - } - } else { - msg, err = pgConn.frontend.Receive() - } + msg, err := pgConn.frontend.Receive() if err != nil { + if errors.Is(err, nbbconn.ErrWouldBlock) { + return nil, err + } + // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -482,13 +449,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { - // Close on anything other than timeout error - everything else is fatal - var netErr net.Error - isNetErr := errors.As(err, &netErr) - if !(isNetErr && netErr.Timeout()) { - pgConn.asyncClose() - } - return nil, err } pgConn.peekedMsg = nil @@ -1176,62 +1136,57 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - // Send copy data - abortCopyChan := make(chan struct{}) - copyErrChan := make(chan error, 1) - signalMessageChan := pgConn.signalMessage() - senderDoneChan := make(chan struct{}) - - go func() { - defer close(senderDoneChan) - - buf := make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - - for { - n, readErr := r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) - if writeErr != nil { - // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. - pgConn.conn.Close() - - copyErrChan <- writeErr - return - } - } - if readErr != nil { - copyErrChan <- readErr - return - } - - select { - case <-abortCopyChan: - return - default: - } + err = pgConn.conn.SetReadDeadline(nbbconn.NonBlockingDeadline) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking := true + defer func() { + if nonblocking { + pgConn.conn.SetReadDeadline(time.Time{}) } }() - var pgErr error - var copyErr error - for copyErr == nil && pgErr == nil { - select { - case copyErr = <-copyErrChan: - case <-signalMessageChan: + buf := iobufpool.Get(65536) + buf[0] = 'd' + + var readErr, pgErr error + for { + // Read chunk from r. + var n int + n, readErr = r.Read(buf[5:cap(buf)]) + + // Send chunk to PostgreSQL. + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) + if writeErr != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + } + + // Abort loop if there was a read error. + if readErr != nil { + break + } + + // Read messages until error or none available. + for { msg, err := pgConn.receiveMessage() if err != nil { + if errors.Is(err, nbbconn.ErrWouldBlock) { + break + } pgConn.asyncClose() return CommandTag{}, preferContextOverNetTimeoutError(ctx, err) } @@ -1239,18 +1194,22 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - default: - signalMessageChan = pgConn.signalMessage() + break } } } - close(abortCopyChan) - <-senderDoneChan - if copyErr == io.EOF || pgErr != nil { + err = pgConn.conn.SetReadDeadline(time.Time{}) + if err != nil { + pgConn.asyncClose() + return CommandTag{}, err + } + nonblocking = false + + if readErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) } err = pgConn.frontend.Flush() if err != nil {