mirror of
https://github.com/jackc/pgx.git
synced 2025-07-14 16:20:15 +00:00
Reduce allocations and copies in pgproto3
Altered chunkreader to never reuse memory. Altered pgproto3 to to copy memory when decoding. Renamed UnmarshalBinary to Decode because of changed semantics.
This commit is contained in:
parent
70b7c9a300
commit
e8eaad520b
@ -9,14 +9,12 @@ type ChunkReader struct {
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
taken bool
|
||||
|
||||
options Options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
MinBufLen int // Minimum buffer length
|
||||
BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192)
|
||||
}
|
||||
|
||||
func NewChunkReader(r io.Reader) *ChunkReader {
|
||||
@ -32,9 +30,6 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
if options.MinBufLen == 0 {
|
||||
options.MinBufLen = 4096
|
||||
}
|
||||
if options.BlockLen == 0 {
|
||||
options.BlockLen = 512
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
r: r,
|
||||
@ -43,8 +38,8 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. buf is only valid until the
|
||||
// next call to Next. If an error occurs, buf will be nil.
|
||||
// Next returns buf filled with the next n bytes. If an error occurs, buf will
|
||||
// be nil.
|
||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
@ -56,17 +51,12 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// available space in buf is less than n
|
||||
if len(r.buf) < n {
|
||||
r.copyBufContents(r.newBuf(n))
|
||||
r.taken = false
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
newBuf := r.buf
|
||||
if r.taken {
|
||||
newBuf = r.newBuf(n)
|
||||
r.taken = false
|
||||
}
|
||||
newBuf := r.newBuf(n)
|
||||
r.copyBufContents(newBuf)
|
||||
}
|
||||
|
||||
@ -79,20 +69,13 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// KeepLast prevents the last data retrieved by Next from being reused by the
|
||||
// ChunkReader.
|
||||
func (r *ChunkReader) KeepLast() {
|
||||
r.taken = true
|
||||
}
|
||||
|
||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||
r.wp += n
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ChunkReader) newBuf(min int) []byte {
|
||||
size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen
|
||||
func (r *ChunkReader) newBuf(size int) []byte {
|
||||
if size < r.options.MinBufLen {
|
||||
size = r.options.MinBufLen
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2})
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
|
||||
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2})
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -59,14 +59,14 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
||||
if bytes.Compare(n1, src[0:5]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
|
||||
}
|
||||
if len(r.buf) != 6 {
|
||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf))
|
||||
if len(r.buf) != 5 {
|
||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderNextReusesBuf(t *testing.T) {
|
||||
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1})
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -90,38 +90,6 @@ func TestChunkReaderNextReusesBuf(t *testing.T) {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(n1, src[4:8]) != 0 {
|
||||
t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
|
||||
}
|
||||
r.KeepLast()
|
||||
|
||||
n2, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[4:8]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ type Authentication struct {
|
||||
|
||||
func (*Authentication) Backend() {}
|
||||
|
||||
func (dst *Authentication) UnmarshalBinary(src []byte) error {
|
||||
func (dst *Authentication) Decode(src []byte) error {
|
||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
||||
|
||||
switch dst.Type {
|
||||
|
@ -13,7 +13,7 @@ type BackendKeyData struct {
|
||||
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
func (dst *BackendKeyData) UnmarshalBinary(src []byte) error {
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ type BindComplete struct{}
|
||||
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
func (dst *BindComplete) UnmarshalBinary(src []byte) error {
|
||||
func (dst *BindComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ type CloseComplete struct{}
|
||||
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
func (dst *CloseComplete) UnmarshalBinary(src []byte) error {
|
||||
func (dst *CloseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
@ -11,14 +11,13 @@ type CommandComplete struct {
|
||||
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
func (dst *CommandComplete) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||
}
|
||||
dst.CommandTag = string(b[:len(b)-1])
|
||||
|
||||
dst.CommandTag = string(src[:idx])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ type CopyBothResponse struct {
|
||||
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
|
@ -13,9 +13,8 @@ type CopyData struct {
|
||||
func (*CopyData) Backend() {}
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
func (dst *CopyData) UnmarshalBinary(src []byte) error {
|
||||
dst.Data = make([]byte, len(src))
|
||||
copy(dst.Data, src)
|
||||
func (dst *CopyData) Decode(src []byte) error {
|
||||
dst.Data = src
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ type CopyInResponse struct {
|
||||
|
||||
func (*CopyInResponse) Backend() {}
|
||||
|
||||
func (dst *CopyInResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
|
@ -13,7 +13,7 @@ type CopyOutResponse struct {
|
||||
|
||||
func (*CopyOutResponse) Backend() {}
|
||||
|
||||
func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
|
@ -13,35 +13,42 @@ type DataRow struct {
|
||||
|
||||
func (*DataRow) Backend() {}
|
||||
|
||||
func (dst *DataRow) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
func (dst *DataRow) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
rp := 0
|
||||
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
dst.Values = make([][]byte, fieldCount)
|
||||
// If the capacity of the values slice is too small OR substantially too
|
||||
// large reallocate. This is too avoid one row with many columns from
|
||||
// permanently allocating memory.
|
||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||
dst.Values = make([][]byte, fieldCount, 32)
|
||||
} else {
|
||||
dst.Values = dst.Values[:fieldCount]
|
||||
}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
if buf.Len() < 4 {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4))))
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
dst.Values[i] = nil
|
||||
} else {
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
value := make([]byte, msgSize)
|
||||
_, err := buf.Read(value)
|
||||
if err != nil {
|
||||
return err
|
||||
dst.Values[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
|
||||
dst.Values[i] = value
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -8,7 +8,7 @@ type EmptyQueryResponse struct{}
|
||||
|
||||
func (*EmptyQueryResponse) Backend() {}
|
||||
|
||||
func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ type ErrorResponse struct {
|
||||
|
||||
func (*ErrorResponse) Backend() {}
|
||||
|
||||
func (dst *ErrorResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
*dst = ErrorResponse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
@ -108,6 +108,6 @@ func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = msg.UnmarshalBinary(msgBody)
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
|
@ -13,20 +13,24 @@ type FunctionCallResponse struct {
|
||||
|
||||
func (*FunctionCallResponse) Backend() {}
|
||||
|
||||
func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 4 {
|
||||
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
resultSize := int(binary.BigEndian.Uint32(buf.Next(4)))
|
||||
if buf.Len() != resultSize {
|
||||
rp := 0
|
||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
if resultSize == -1 {
|
||||
dst.Result = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
dst.Result = make([]byte, resultSize)
|
||||
copy(dst.Result, buf.Bytes())
|
||||
|
||||
dst.Result = src[rp:]
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,7 @@ type NoData struct{}
|
||||
|
||||
func (*NoData) Backend() {}
|
||||
|
||||
func (dst *NoData) UnmarshalBinary(src []byte) error {
|
||||
func (dst *NoData) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
@ -4,8 +4,8 @@ type NoticeResponse ErrorResponse
|
||||
|
||||
func (*NoticeResponse) Backend() {}
|
||||
|
||||
func (dst *NoticeResponse) UnmarshalBinary(src []byte) error {
|
||||
return (*ErrorResponse)(dst).UnmarshalBinary(src)
|
||||
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||
return (*ErrorResponse)(dst).Decode(src)
|
||||
}
|
||||
|
||||
func (src *NoticeResponse) MarshalBinary() ([]byte, error) {
|
||||
|
@ -14,7 +14,7 @@ type NotificationResponse struct {
|
||||
|
||||
func (*NotificationResponse) Backend() {}
|
||||
|
||||
func (dst *NotificationResponse) UnmarshalBinary(src []byte) error {
|
||||
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
@ -12,7 +12,7 @@ type ParameterDescription struct {
|
||||
|
||||
func (*ParameterDescription) Backend() {}
|
||||
|
||||
func (dst *ParameterDescription) UnmarshalBinary(src []byte) error {
|
||||
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
|
@ -13,7 +13,7 @@ type ParameterStatus struct {
|
||||
|
||||
func (*ParameterStatus) Backend() {}
|
||||
|
||||
func (dst *ParameterStatus) UnmarshalBinary(src []byte) error {
|
||||
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
|
@ -8,7 +8,7 @@ type ParseComplete struct{}
|
||||
|
||||
func (*ParseComplete) Backend() {}
|
||||
|
||||
func (dst *ParseComplete) UnmarshalBinary(src []byte) error {
|
||||
func (dst *ParseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
@ -2,8 +2,13 @@ package pgproto3
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
//
|
||||
// Decode is allowed and expected to retain a reference to data after
|
||||
// returning (unlike encoding.BinaryUnmarshaler).
|
||||
type Message interface {
|
||||
UnmarshalBinary(data []byte) error
|
||||
Decode(data []byte) error
|
||||
MarshalBinary() (data []byte, err error)
|
||||
}
|
||||
|
||||
@ -17,58 +22,6 @@ type BackendMessage interface {
|
||||
Backend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) {
|
||||
// switch typeByte {
|
||||
// case '1':
|
||||
// return ParseParseComplete(body)
|
||||
// case '2':
|
||||
// return ParseBindComplete(body)
|
||||
// case 'C':
|
||||
// return ParseCommandComplete(body)
|
||||
// case 'D':
|
||||
// return ParseDataRow(body)
|
||||
// case 'E':
|
||||
// return ParseErrorResponse(body)
|
||||
// case 'K':
|
||||
// return ParseBackendKeyData(body)
|
||||
// case 'R':
|
||||
// return ParseAuthentication(body)
|
||||
// case 'S':
|
||||
// return ParseParameterStatus(body)
|
||||
// case 'T':
|
||||
// return ParseRowDescription(body)
|
||||
// case 't':
|
||||
// return ParseParameterDescription(body)
|
||||
// case 'Z':
|
||||
// return ParseReadyForQuery(body)
|
||||
// default:
|
||||
// return ParseUnknownMessage(typeByte, body)
|
||||
// }
|
||||
// }
|
||||
|
||||
// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) {
|
||||
// switch typeByte {
|
||||
// case 'B':
|
||||
// return ParseBind(body)
|
||||
// case 'D':
|
||||
// return ParseDescribe(body)
|
||||
// case 'E':
|
||||
// return ParseExecute(body)
|
||||
// case 'P':
|
||||
// return ParseParse(body)
|
||||
// case 'p':
|
||||
// return ParsePasswordMessage(body)
|
||||
// case 'Q':
|
||||
// return ParseQuery(body)
|
||||
// case 'S':
|
||||
// return ParseSync(body)
|
||||
// case 'X':
|
||||
// return ParseTerminate(body)
|
||||
// default:
|
||||
// return ParseUnknownMessage(typeByte, body)
|
||||
// }
|
||||
// }
|
||||
|
||||
type invalidMessageLenErr struct {
|
||||
messageType string
|
||||
expectedLen int
|
||||
|
@ -11,7 +11,7 @@ type Query struct {
|
||||
|
||||
func (*Query) Frontend() {}
|
||||
|
||||
func (dst *Query) UnmarshalBinary(src []byte) error {
|
||||
func (dst *Query) Decode(src []byte) error {
|
||||
i := bytes.IndexByte(src, 0)
|
||||
if i != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
|
@ -10,7 +10,7 @@ type ReadyForQuery struct {
|
||||
|
||||
func (*ReadyForQuery) Backend() {}
|
||||
|
||||
func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error {
|
||||
func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||
if len(src) != 1 {
|
||||
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ type RowDescription struct {
|
||||
|
||||
func (*RowDescription) Backend() {}
|
||||
|
||||
func (dst *RowDescription) UnmarshalBinary(src []byte) error {
|
||||
func (dst *RowDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
|
Loading…
x
Reference in New Issue
Block a user