tmp-nbconn-win
Jack Christensen 2023-02-26 07:25:30 -06:00
parent 1bf3319330
commit c525bf97cf
2 changed files with 71 additions and 32 deletions

View File

@ -11,9 +11,11 @@
package nbconn
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"math/rand"
"net"
"os"
"sync"
@ -173,8 +175,6 @@ func (c *NetConn) Write(b []byte) (n int, err error) {
return 0, errClosed
}
fmt.Println("NetConn Write", len(b))
buf := iobufpool.Get(len(b))
copy(*buf, b)
c.writeQueue.pushBack(buf)
@ -279,6 +279,8 @@ func (c *NetConn) Flush() error {
return c.flush()
}
var LogWritten bytes.Buffer
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *NetConn) flush() error {
var stopChan chan struct{}
@ -296,11 +298,7 @@ func (c *NetConn) flush() error {
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := *buf
for len(remainingBuf) > 0 {
if len(remainingBuf) == 24 {
fmt.Println("break")
}
n, err := c.nonblockingWrite(remainingBuf)
fmt.Println("flush nonblockingWrite", n, err)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
@ -380,6 +378,9 @@ func (c *NetConn) isClosed() bool {
}
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
if rand.Float32() > 0.9 {
// return 0, ErrWouldBlock
}
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
@ -391,6 +392,11 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
defer func() {
LogWritten.Write(b[:n])
fmt.Println(n)
}()
deadline := time.Now().Add(fakeNonblockingWriteWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.conn.SetWriteDeadline(deadline)
@ -413,12 +419,6 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
}
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
defer func() {
fmt.Println("nonblockingRead", n, err)
if n == 0 {
fmt.Println("break2")
}
}()
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {

View File

@ -1,10 +1,12 @@
package nbconn_test
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"os"
"strings"
"testing"
"time"
@ -67,6 +69,34 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
2qWm8jTPeDC3sq+67s2oojHf+Q==
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
type interceptConn struct {
net.Conn
}
var logRead bytes.Buffer
func (c *interceptConn) Read(b []byte) (int, error) {
n, err := c.Conn.Read(b)
logRead.Write(b[:n])
fmt.Println(logRead.Len(), nbconn.LogWritten.Len())
for i := 0; i < logRead.Len(); i++ {
if logRead.Bytes()[i] != nbconn.LogWritten.Bytes()[i] {
fmt.Println("mismatch at", i)
fmt.Println(logRead.Bytes()[i-20 : i+20])
fmt.Println(nbconn.LogWritten.Bytes()[i-20 : i+20])
fmt.Println(
bytes.Contains(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
bytes.Index(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
i-bytes.Index(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
)
os.Exit(1)
}
}
return n, err
}
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
for _, tt := range []struct {
name string
@ -121,7 +151,7 @@ func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote n
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
require.NoError(t, err)
tlsServer := tls.Server(remote, &tls.Config{
tlsServer := tls.Server(&interceptConn{remote}, &tls.Config{
Certificates: []tls.Certificate{cert},
})
serverTLSHandshakeChan := make(chan error)
@ -278,6 +308,9 @@ func TestInternalNonBlockingWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
for i := range writeBuf {
writeBuf[i] = 1
}
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
@ -295,6 +328,15 @@ func TestInternalNonBlockingWrite(t *testing.T) {
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(remote, readBuf)
fmt.Println(logRead.Len(), nbconn.LogWritten.Len())
for i := 0; i < logRead.Len(); i++ {
if logRead.Bytes()[i] != nbconn.LogWritten.Bytes()[i] {
fmt.Println("mismatch at", i)
break
}
}
errChan <- err
}()
@ -336,36 +378,33 @@ func TestTLSCrash(t *testing.T) {
remote = tlsServer
const blockSize = 4 * 1024 * 1024
const blockCount = 16
errChan := make(chan error, 1)
go func() {
for i := 0; i < blockCount; i++ {
writeBuf := make([]byte, blockSize)
_, err := remote.Write(writeBuf)
if err != nil {
errChan <- err
return
}
writeBuf := make([]byte, blockSize)
_, err := remote.Write(writeBuf)
if err != nil {
errChan <- err
return
}
readBuf := make([]byte, blockSize)
_, err = io.ReadFull(remote, readBuf)
if err != nil {
errChan <- err
return
}
readBuf := make([]byte, blockSize)
_, err = io.ReadFull(remote, readBuf)
if err != nil {
errChan <- err
return
}
errChan <- nil
}()
for i := 0; i < blockCount; i++ {
readBuf := make([]byte, blockSize)
_, err = io.ReadFull(local, readBuf)
for i := 0; i < 64; i++ {
writeBuf := make([]byte, blockSize/64)
_, err = local.Write(writeBuf)
require.NoError(t, err)
writeBuf := make([]byte, blockSize)
_, err := local.Write(writeBuf)
readBuf := make([]byte, blockSize/64)
_, err = io.ReadFull(local, readBuf)
require.NoError(t, err)
}