pgx/pgproto3/backend.go
Jack Christensen cbb3fa5ecc Fix reading interrupted messages
When a message is received and a timeout occurs after reading the
header but before reading the entire body the connection state could
be corrupted due to the header being consumed. The next read would
consider the body of the previous message as the header for the next.

fixes #348
2017-12-16 13:45:22 -06:00

111 lines
1.9 KiB
Go

package pgproto3
import (
"encoding/binary"
"io"
"github.com/jackc/pgx/chunkreader"
"github.com/pkg/errors"
)
// Backend is the receiving side for frontend messages: it reads
// FrontendMessages (and the initial StartupMessage) from a chunk reader that
// wraps the io.Reader given to NewBackend, and writes encoded BackendMessages
// to w.
type Backend struct {
cr *chunkreader.ChunkReader
w io.Writer
// Frontend message flyweights. Receive (and ReceiveStartupMessage, for
// startupMessage) decodes each incoming message into the field matching its
// type and returns a pointer to that field, so the value is overwritten by
// the next message of the same type rather than allocated per message.
bind Bind
_close Close
describe Describe
execute Execute
flush Flush
parse Parse
passwordMessage PasswordMessage
query Query
startupMessage StartupMessage
sync Sync
terminate Terminate
// State for a message whose 5-byte header has been consumed but whose body
// has not been read yet (e.g. the body read failed with a timeout). While
// partialMsg is true, Receive does not read a new header and instead reuses
// msgType and bodyLen (declared length minus the 4-byte length field).
bodyLen int
msgType byte
partialMsg bool
}
// NewBackend returns a Backend that reads frontend messages from r (through
// a chunk reader) and writes backend messages to w. The error result is
// always nil in the current implementation.
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
	b := &Backend{w: w}
	b.cr = chunkreader.NewChunkReader(r)
	return b, nil
}
// Send encodes msg into a fresh buffer and hands it to the underlying writer
// in a single Write call. The writer's error, if any, is returned unchanged.
func (b *Backend) Send(msg BackendMessage) error {
	encoded := msg.Encode(nil)
	if _, err := b.w.Write(encoded); err != nil {
		return err
	}
	return nil
}
// ReceiveStartupMessage reads one startup packet (a 4-byte big-endian length
// that counts itself, followed by the body) and decodes it into the Backend's
// StartupMessage flyweight. The returned pointer refers to a field of b and
// is overwritten by a later call.
//
// The length is supplied by the client, so it is validated before anything
// is read: values below 4 (which previously wrapped around in uint32
// arithmetic) and values above maxStartupPacketLen are rejected with an
// error.
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
	// PostgreSQL's server applies the same cap (MAX_STARTUP_PACKET_LENGTH).
	// Enforcing it keeps a bogus length from being passed straight to
	// cr.Next as a request for gigabytes of data.
	const maxStartupPacketLen = 10000

	buf, err := b.cr.Next(4)
	if err != nil {
		return nil, err
	}

	// Convert to int before comparing so that values which become negative
	// on platforms with a 32-bit int are also caught.
	msgLen := int(binary.BigEndian.Uint32(buf))
	if msgLen < 4 || msgLen > maxStartupPacketLen {
		return nil, errors.Errorf("invalid length of startup packet: %d", msgLen)
	}

	buf, err = b.cr.Next(msgLen - 4)
	if err != nil {
		return nil, err
	}

	if err := b.startupMessage.Decode(buf); err != nil {
		return nil, err
	}
	return &b.startupMessage, nil
}
// Receive reads and decodes the next frontend message.
//
// The returned message points at a flyweight field of b. That field is
// overwritten when a later call receives another message of the same type.
//
// If the body read fails after the 5-byte header has been consumed (for
// example on a read timeout), the header's type and length are kept. The
// next call then resumes by reading the body, instead of misinterpreting the
// body as a new header.
//
// A header whose declared length is smaller than the 4-byte length field
// itself is rejected with an error. Previously such a header produced a
// negative body size that was passed to the chunk reader. The stream cannot
// be resynchronized after that, so callers should close the connection.
//
// An unknown message type is reported as an error. The header is left
// pending, so subsequent calls report the same error.
func (b *Backend) Receive() (FrontendMessage, error) {
	if !b.partialMsg {
		header, err := b.cr.Next(5)
		if err != nil {
			return nil, err
		}

		// The length field counts its own 4 bytes but not the type byte.
		// Converting to int before comparing also catches lengths that turn
		// negative on platforms with a 32-bit int.
		msgLen := int(binary.BigEndian.Uint32(header[1:]))
		if msgLen < 4 {
			return nil, errors.Errorf("invalid message length: %d", msgLen)
		}

		b.msgType = header[0]
		b.bodyLen = msgLen - 4
		// Remember the header so a failed body read below can be retried
		// without treating the start of the body as a new header.
		b.partialMsg = true
	}

	var msg FrontendMessage
	switch b.msgType {
	case 'B':
		msg = &b.bind
	case 'C':
		msg = &b._close
	case 'D':
		msg = &b.describe
	case 'E':
		msg = &b.execute
	case 'H':
		msg = &b.flush
	case 'P':
		msg = &b.parse
	case 'p':
		msg = &b.passwordMessage
	case 'Q':
		msg = &b.query
	case 'S':
		msg = &b.sync
	case 'X':
		msg = &b.terminate
	default:
		return nil, errors.Errorf("unknown message type: %c", b.msgType)
	}

	msgBody, err := b.cr.Next(b.bodyLen)
	if err != nil {
		return nil, err
	}
	// The whole message has been consumed; the next call starts at a header.
	b.partialMsg = false

	err = msg.Decode(msgBody)
	return msg, err
}