Fix reading interrupted messages

When an message is received and a timeout occurs after reading the
header but before reading the entire body the connection state could
be corrupted due to the header being consumed. The next read would
consider the body of the previous message as the header for the next.

fixes #348
pull/359/merge
Jack Christensen 2017-12-16 13:45:22 -06:00
parent 1ed4024c70
commit cbb3fa5ecc
4 changed files with 135 additions and 18 deletions

View File

@ -24,6 +24,10 @@ type Backend struct {
startupMessage StartupMessage
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
}
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
}
func (b *Backend) Receive() (FrontendMessage, error) {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
}
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
var msg FrontendMessage
switch msgType {
switch b.msgType {
case 'B':
msg = &b.bind
case 'C':
@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) {
case 'X':
msg = &b.terminate
default:
return nil, errors.Errorf("unknown message type: %c", msgType)
return nil, errors.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(bodyLen)
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, err
}
b.partialMsg = false
err = msg.Decode(msgBody)
return msg, err
}

37
pgproto3/backend_test.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3_test
import (
"testing"
"github.com/jackc/pgx/pgproto3"
)
func TestBackendReceiveInterrupted(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 0, 6})
backend, err := pgproto3.NewBackend(server, nil)
if err != nil {
t.Fatal(err)
}
msg, err := backend.Receive()
if err == nil {
t.Fatal("expected err")
}
if msg != nil {
t.Fatalf("did not expect msg, but %v", msg)
}
server.push([]byte{'I', 0})
msg, err = backend.Receive()
if err != nil {
t.Fatal(err)
}
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
t.Fatalf("unexpected msg: %v", msg)
}
}

View File

@ -34,6 +34,10 @@ type Frontend struct {
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
bodyLen int
msgType byte
partialMsg bool
}
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error {
}
func (b *Frontend) Receive() (BackendMessage, error) {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, err
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
}
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
var msg BackendMessage
switch msgType {
switch b.msgType {
case '1':
msg = &b.parseComplete
case '2':
@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) {
case 'Z':
msg = &b.readyForQuery
default:
return nil, errors.Errorf("unknown message type: %c", msgType)
return nil, errors.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(bodyLen)
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, err
}
b.partialMsg = false
err = msg.Decode(msgBody)
return msg, err
}

62
pgproto3/frontend_test.go Normal file
View File

@ -0,0 +1,62 @@
package pgproto3_test
import (
"testing"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgproto3"
)
type interruptReader struct {
chunks [][]byte
}
func (ir *interruptReader) Read(p []byte) (n int, err error) {
if len(ir.chunks) == 0 {
return 0, errors.New("no data")
}
n = copy(p, ir.chunks[0])
if n != len(ir.chunks[0]) {
panic("this test reader doesn't support partial reads of chunks")
}
ir.chunks = ir.chunks[1:]
return n, nil
}
func (ir *interruptReader) push(p []byte) {
ir.chunks = append(ir.chunks, p)
}
func TestFrontendReceiveInterrupted(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Z', 0, 0, 0, 5})
frontend, err := pgproto3.NewFrontend(server, nil)
if err != nil {
t.Fatal(err)
}
msg, err := frontend.Receive()
if err == nil {
t.Fatal("expected err")
}
if msg != nil {
t.Fatalf("did not expect msg, but %v", msg)
}
server.push([]byte{'I'})
msg, err = frontend.Receive()
if err != nil {
t.Fatal(err)
}
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
t.Fatalf("unexpected msg: %v", msg)
}
}