mirror of
https://github.com/jackc/pgx.git
synced 2025-05-22 23:40:56 +00:00
When a message is received and a timeout occurs after reading the header but before reading the entire body, the connection state could be corrupted because the header has already been consumed. The next read would then treat the body of the previous message as the header of the next one. Fixes #348.
111 lines
1.9 KiB
Go
111 lines
1.9 KiB
Go
package pgproto3
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
|
|
"github.com/jackc/pgx/chunkreader"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
type Backend struct {
|
|
cr *chunkreader.ChunkReader
|
|
w io.Writer
|
|
|
|
// Frontend message flyweights
|
|
bind Bind
|
|
_close Close
|
|
describe Describe
|
|
execute Execute
|
|
flush Flush
|
|
parse Parse
|
|
passwordMessage PasswordMessage
|
|
query Query
|
|
startupMessage StartupMessage
|
|
sync Sync
|
|
terminate Terminate
|
|
|
|
bodyLen int
|
|
msgType byte
|
|
partialMsg bool
|
|
}
|
|
|
|
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
|
cr := chunkreader.NewChunkReader(r)
|
|
return &Backend{cr: cr, w: w}, nil
|
|
}
|
|
|
|
func (b *Backend) Send(msg BackendMessage) error {
|
|
_, err := b.w.Write(msg.Encode(nil))
|
|
return err
|
|
}
|
|
|
|
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
|
buf, err := b.cr.Next(4)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
|
|
|
buf, err = b.cr.Next(msgSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = b.startupMessage.Decode(buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &b.startupMessage, nil
|
|
}
|
|
|
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
|
if !b.partialMsg {
|
|
header, err := b.cr.Next(5)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b.msgType = header[0]
|
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
|
b.partialMsg = true
|
|
}
|
|
|
|
var msg FrontendMessage
|
|
switch b.msgType {
|
|
case 'B':
|
|
msg = &b.bind
|
|
case 'C':
|
|
msg = &b._close
|
|
case 'D':
|
|
msg = &b.describe
|
|
case 'E':
|
|
msg = &b.execute
|
|
case 'H':
|
|
msg = &b.flush
|
|
case 'P':
|
|
msg = &b.parse
|
|
case 'p':
|
|
msg = &b.passwordMessage
|
|
case 'Q':
|
|
msg = &b.query
|
|
case 'S':
|
|
msg = &b.sync
|
|
case 'X':
|
|
msg = &b.terminate
|
|
default:
|
|
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
|
}
|
|
|
|
msgBody, err := b.cr.Next(b.bodyLen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b.partialMsg = false
|
|
|
|
err = msg.Decode(msgBody)
|
|
return msg, err
|
|
}
|