mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
msgReader pre-buffers messages when possible
This commit is contained in:
parent
855276e2cf
commit
50b0bea9e5
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||
@ -35,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) {
|
||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
_, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
n, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||
r.msgBytesRemaining -= int32(n)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := r.reader.Peek(5)
|
||||
if err != nil {
|
||||
r.fatal(err)
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
msgType := b[0]
|
||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||
|
||||
// Try to preload bufio.Reader with entire message
|
||||
b, err = r.reader.Peek(5 + int(payloadSize))
|
||||
if err != nil && err != bufio.ErrBufferFull {
|
||||
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||
r.fatal(err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.msgBytesRemaining = payloadSize
|
||||
r.reader.Discard(5)
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
|
189
msg_reader_test.go
Normal file
189
msg_reader_test.go
Normal file
@ -0,0 +1,189 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgmock/pgmsg"
|
||||
)
|
||||
|
||||
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
msgType byte
|
||||
payloadSize int32
|
||||
buffered bool
|
||||
}{
|
||||
{1, 50, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 24000, false},
|
||||
{9, 4000, true},
|
||||
{1, 1500, true},
|
||||
{2, 0, true},
|
||||
{3, 500, true},
|
||||
{4, 1050, true},
|
||||
{5, 1500, true},
|
||||
{6, 1500, true},
|
||||
{7, 4000, true},
|
||||
{8, 14000, false},
|
||||
{9, 0, true},
|
||||
{1, 500, true},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
_, err = conn.Write([]byte{tt.msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, int(tt.payloadSize))
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if msgType != tt.msgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
|
||||
}
|
||||
|
||||
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
|
||||
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
testCount := 10000
|
||||
|
||||
go func() {
|
||||
var bigEndian pgmsg.BigEndianBuf
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < testCount; i++ {
|
||||
msgType := byte(i)
|
||||
|
||||
_, err = conn.Write([]byte{msgType})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msgSize := i % 4000
|
||||
|
||||
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
payload := make([]byte, msgSize)
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mr := &msgReader{
|
||||
reader: bufio.NewReader(conn),
|
||||
shouldLog: func(int) bool { return false },
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
|
||||
i := 0
|
||||
for {
|
||||
msgType, err := mr.rxMsg()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
continue
|
||||
} else {
|
||||
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
expectedMsgType := byte(i)
|
||||
if msgType != expectedMsgType {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
|
||||
}
|
||||
|
||||
expectedMsgSize := i % 4000
|
||||
payload := mr.readBytes(mr.msgBytesRemaining)
|
||||
if mr.err != nil {
|
||||
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
|
||||
}
|
||||
if len(payload) != expectedMsgSize {
|
||||
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
|
||||
}
|
||||
|
||||
i++
|
||||
if i == testCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user