diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 5d07b65f..bb27d0ec 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -67,17 +67,23 @@ func (c *Conn) Read(b []byte) (n int, err error) { return 0, err } - buf := c.readQueue.popFront() - if buf != nil { - n = copy(b, buf) - if n < len(buf) { - buf = buf[n:] + for n < len(b) { + buf := c.readQueue.popFront() + if buf == nil { + break + } + copiedN := copy(b[n:], buf) + if copiedN < len(buf) { + buf = buf[copiedN:] c.readQueue.pushFront(buf) } else { releaseBuf(buf) } + n += copiedN + } + + if n == len(b) { return n, nil - // TODO - must return error if n != len(b) } var readNonblocking bool @@ -85,11 +91,14 @@ func (c *Conn) Read(b []byte) (n int, err error) { readNonblocking = c.readNonblocking c.readDeadlineLock.Unlock() + var readN int if readNonblocking { - return c.nonblockingRead(b) + readN, err = c.nonblockingRead(b[n:]) } else { - return c.netConn.Read(b) + readN, err = c.netConn.Read(b[n:]) } + n += readN + return n, err } // Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index fd5921a2..334f36ee 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -127,3 +127,192 @@ func TestNonBlockingRead(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 4, n) } + +func TestReadPreviouslyBuffered(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + conn := nbconn.New(local) + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flust must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 5) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 5, n) + require.Equal(t, []byte("alpha"), readBuf) +} + +func TestReadPreviouslyBufferedPartialRead(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + conn := nbconn.New(local) + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flust must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 2) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.Equal(t, []byte("al"), readBuf) + + readBuf = make([]byte, 3) + n, err = conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("pha"), readBuf) +} + +func TestReadMultiplePreviouslyBuffered(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + conn := nbconn.New(local) + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flust must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + readBuf := make([]byte, 9) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) +} + +func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { + local, remote := net.Pipe() + defer func() { + local.Close() + remote.Close() + }() + + flushCompleteChan := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + err := func() error { + _, err := remote.Write([]byte("alpha")) + if err != nil { + return err + } + + readBuf := make([]byte, 4) + _, err = remote.Read(readBuf) + if err != nil { + return err + } + + <-flushCompleteChan + + _, err = remote.Write([]byte("beta")) + if err != nil { + return err + } + + return nil + }() + errChan <- err + }() + + conn := nbconn.New(local) + + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + + // Because net.Pipe() is synchronous conn.Flust must buffer a read. + err = conn.Flush() + require.NoError(t, err) + + close(flushCompleteChan) + + readBuf := make([]byte, 9) + n, err := conn.Read(readBuf) + require.NoError(t, err) + require.EqualValues(t, 9, n) + require.Equal(t, []byte("alphabeta"), readBuf) +}