mirror of https://github.com/jackc/pgx.git
wip
parent
1bf3319330
commit
c525bf97cf
|
@ -11,9 +11,11 @@
|
|||
package nbconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
@ -173,8 +175,6 @@ func (c *NetConn) Write(b []byte) (n int, err error) {
|
|||
return 0, errClosed
|
||||
}
|
||||
|
||||
fmt.Println("NetConn Write", len(b))
|
||||
|
||||
buf := iobufpool.Get(len(b))
|
||||
copy(*buf, b)
|
||||
c.writeQueue.pushBack(buf)
|
||||
|
@ -279,6 +279,8 @@ func (c *NetConn) Flush() error {
|
|||
return c.flush()
|
||||
}
|
||||
|
||||
var LogWritten bytes.Buffer
|
||||
|
||||
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
|
||||
func (c *NetConn) flush() error {
|
||||
var stopChan chan struct{}
|
||||
|
@ -296,11 +298,7 @@ func (c *NetConn) flush() error {
|
|||
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
|
||||
remainingBuf := *buf
|
||||
for len(remainingBuf) > 0 {
|
||||
if len(remainingBuf) == 24 {
|
||||
fmt.Println("break")
|
||||
}
|
||||
n, err := c.nonblockingWrite(remainingBuf)
|
||||
fmt.Println("flush nonblockingWrite", n, err)
|
||||
remainingBuf = remainingBuf[n:]
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
|
@ -380,6 +378,9 @@ func (c *NetConn) isClosed() bool {
|
|||
}
|
||||
|
||||
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
|
||||
if rand.Float32() > 0.9 {
|
||||
// return 0, ErrWouldBlock
|
||||
}
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingWrite(b)
|
||||
} else {
|
||||
|
@ -391,6 +392,11 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
|||
c.writeDeadlineLock.Lock()
|
||||
defer c.writeDeadlineLock.Unlock()
|
||||
|
||||
defer func() {
|
||||
LogWritten.Write(b[:n])
|
||||
fmt.Println(n)
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(fakeNonblockingWriteWaitDuration)
|
||||
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
|
||||
err = c.conn.SetWriteDeadline(deadline)
|
||||
|
@ -413,12 +419,6 @@ func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
|
||||
defer func() {
|
||||
fmt.Println("nonblockingRead", n, err)
|
||||
if n == 0 {
|
||||
fmt.Println("break2")
|
||||
}
|
||||
}()
|
||||
if c.rawConn == nil {
|
||||
return c.fakeNonblockingRead(b)
|
||||
} else {
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package nbconn_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -67,6 +69,34 @@ pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
|
|||
2qWm8jTPeDC3sq+67s2oojHf+Q==
|
||||
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
|
||||
|
||||
type interceptConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
var logRead bytes.Buffer
|
||||
|
||||
func (c *interceptConn) Read(b []byte) (int, error) {
|
||||
n, err := c.Conn.Read(b)
|
||||
|
||||
logRead.Write(b[:n])
|
||||
fmt.Println(logRead.Len(), nbconn.LogWritten.Len())
|
||||
for i := 0; i < logRead.Len(); i++ {
|
||||
if logRead.Bytes()[i] != nbconn.LogWritten.Bytes()[i] {
|
||||
fmt.Println("mismatch at", i)
|
||||
fmt.Println(logRead.Bytes()[i-20 : i+20])
|
||||
fmt.Println(nbconn.LogWritten.Bytes()[i-20 : i+20])
|
||||
fmt.Println(
|
||||
bytes.Contains(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
|
||||
bytes.Index(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
|
||||
i-bytes.Index(nbconn.LogWritten.Bytes(), logRead.Bytes()[i:i+10]),
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
|
@ -121,7 +151,7 @@ func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote n
|
|||
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsServer := tls.Server(remote, &tls.Config{
|
||||
tlsServer := tls.Server(&interceptConn{remote}, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
serverTLSHandshakeChan := make(chan error)
|
||||
|
@ -278,6 +308,9 @@ func TestInternalNonBlockingWrite(t *testing.T) {
|
|||
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
writeBuf := make([]byte, deadlockSize)
|
||||
for i := range writeBuf {
|
||||
writeBuf[i] = 1
|
||||
}
|
||||
n, err := conn.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, deadlockSize, n)
|
||||
|
@ -295,6 +328,15 @@ func TestInternalNonBlockingWrite(t *testing.T) {
|
|||
|
||||
readBuf := make([]byte, deadlockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
|
||||
fmt.Println(logRead.Len(), nbconn.LogWritten.Len())
|
||||
for i := 0; i < logRead.Len(); i++ {
|
||||
if logRead.Bytes()[i] != nbconn.LogWritten.Bytes()[i] {
|
||||
fmt.Println("mismatch at", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
|
@ -336,36 +378,33 @@ func TestTLSCrash(t *testing.T) {
|
|||
remote = tlsServer
|
||||
|
||||
const blockSize = 4 * 1024 * 1024
|
||||
const blockCount = 16
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
for i := 0; i < blockCount; i++ {
|
||||
writeBuf := make([]byte, blockSize)
|
||||
_, err := remote.Write(writeBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
writeBuf := make([]byte, blockSize)
|
||||
_, err := remote.Write(writeBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
readBuf := make([]byte, blockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
readBuf := make([]byte, blockSize)
|
||||
_, err = io.ReadFull(remote, readBuf)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
errChan <- nil
|
||||
}()
|
||||
|
||||
for i := 0; i < blockCount; i++ {
|
||||
readBuf := make([]byte, blockSize)
|
||||
_, err = io.ReadFull(local, readBuf)
|
||||
for i := 0; i < 64; i++ {
|
||||
writeBuf := make([]byte, blockSize/64)
|
||||
_, err = local.Write(writeBuf)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeBuf := make([]byte, blockSize)
|
||||
_, err := local.Write(writeBuf)
|
||||
readBuf := make([]byte, blockSize/64)
|
||||
_, err = io.ReadFull(local, readBuf)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue