pgx/msg_reader_test.go

190 lines
3.5 KiB
Go

package pgx
import (
"bufio"
"net"
"testing"
"time"
"github.com/jackc/pgmock/pgmsg"
)
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
t.Parallel()
tests := []struct {
msgType byte
payloadSize int32
buffered bool
}{
{1, 50, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 24000, false},
{9, 4000, true},
{1, 1500, true},
{2, 0, true},
{3, 500, true},
{4, 1050, true},
{5, 1500, true},
{6, 1500, true},
{7, 4000, true},
{8, 14000, false},
{9, 0, true},
{1, 500, true},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for _, tt := range tests {
_, err = conn.Write([]byte{tt.msgType})
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, int(tt.payloadSize))
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
for i, tt := range tests {
msgType, err := mr.rxMsg()
if err != nil {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
if msgType != tt.msgType {
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
}
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
}
}
}
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
testCount := 10000
go func() {
var bigEndian pgmsg.BigEndianBuf
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
for i := 0; i < testCount; i++ {
msgType := byte(i)
_, err = conn.Write([]byte{msgType})
if err != nil {
t.Fatal(err)
}
msgSize := i % 4000
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
if err != nil {
t.Fatal(err)
}
payload := make([]byte, msgSize)
_, err = conn.Write(payload)
if err != nil {
t.Fatal(err)
}
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
mr := &msgReader{
reader: bufio.NewReader(conn),
shouldLog: func(int) bool { return false },
}
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
i := 0
for {
msgType, err := mr.rxMsg()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
continue
} else {
t.Fatalf("%d. Unexpected error: %v", i, err)
}
}
expectedMsgType := byte(i)
if msgType != expectedMsgType {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
}
expectedMsgSize := i % 4000
payload := mr.readBytes(mr.msgBytesRemaining)
if mr.err != nil {
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
}
if len(payload) != expectedMsgSize {
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
}
i++
if i == testCount {
break
}
}
}