Fix reading interrupted messages

When an message is received and a timeout occurs after reading the
header but before reading the entire body the connection state could
be corrupted due to the header being consumed. The next read would
consider the body of the previous message as the header for the next.

fixes #348
pull/359/merge
Jack Christensen 2017-12-16 13:45:22 -06:00
parent 1ed4024c70
commit cbb3fa5ecc
4 changed files with 135 additions and 18 deletions

View File

@ -24,6 +24,10 @@ type Backend struct {
startupMessage StartupMessage startupMessage StartupMessage
sync Sync sync Sync
terminate Terminate terminate Terminate
bodyLen int
msgType byte
partialMsg bool
} }
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
} }
func (b *Backend) Receive() (FrontendMessage, error) { func (b *Backend) Receive() (FrontendMessage, error) {
header, err := b.cr.Next(5) if !b.partialMsg {
if err != nil { header, err := b.cr.Next(5)
return nil, err if err != nil {
return nil, err
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
} }
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
var msg FrontendMessage var msg FrontendMessage
switch msgType { switch b.msgType {
case 'B': case 'B':
msg = &b.bind msg = &b.bind
case 'C': case 'C':
@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) {
case 'X': case 'X':
msg = &b.terminate msg = &b.terminate
default: default:
return nil, errors.Errorf("unknown message type: %c", msgType) return nil, errors.Errorf("unknown message type: %c", b.msgType)
} }
msgBody, err := b.cr.Next(bodyLen) msgBody, err := b.cr.Next(b.bodyLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.partialMsg = false
err = msg.Decode(msgBody) err = msg.Decode(msgBody)
return msg, err return msg, err
} }

37
pgproto3/backend_test.go Normal file
View File

@ -0,0 +1,37 @@
package pgproto3_test
import (
"testing"
"github.com/jackc/pgx/pgproto3"
)
func TestBackendReceiveInterrupted(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 0, 6})
backend, err := pgproto3.NewBackend(server, nil)
if err != nil {
t.Fatal(err)
}
msg, err := backend.Receive()
if err == nil {
t.Fatal("expected err")
}
if msg != nil {
t.Fatalf("did not expect msg, but %v", msg)
}
server.push([]byte{'I', 0})
msg, err = backend.Receive()
if err != nil {
t.Fatal(err)
}
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
t.Fatalf("unexpected msg: %v", msg)
}
}

View File

@ -34,6 +34,10 @@ type Frontend struct {
parseComplete ParseComplete parseComplete ParseComplete
readyForQuery ReadyForQuery readyForQuery ReadyForQuery
rowDescription RowDescription rowDescription RowDescription
bodyLen int
msgType byte
partialMsg bool
} }
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error {
} }
func (b *Frontend) Receive() (BackendMessage, error) { func (b *Frontend) Receive() (BackendMessage, error) {
header, err := b.cr.Next(5) if !b.partialMsg {
if err != nil { header, err := b.cr.Next(5)
return nil, err if err != nil {
return nil, err
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
} }
msgType := header[0]
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
var msg BackendMessage var msg BackendMessage
switch msgType { switch b.msgType {
case '1': case '1':
msg = &b.parseComplete msg = &b.parseComplete
case '2': case '2':
@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) {
case 'Z': case 'Z':
msg = &b.readyForQuery msg = &b.readyForQuery
default: default:
return nil, errors.Errorf("unknown message type: %c", msgType) return nil, errors.Errorf("unknown message type: %c", b.msgType)
} }
msgBody, err := b.cr.Next(bodyLen) msgBody, err := b.cr.Next(b.bodyLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.partialMsg = false
err = msg.Decode(msgBody) err = msg.Decode(msgBody)
return msg, err return msg, err
} }

62
pgproto3/frontend_test.go Normal file
View File

@ -0,0 +1,62 @@
package pgproto3_test
import (
"testing"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgproto3"
)
type interruptReader struct {
chunks [][]byte
}
func (ir *interruptReader) Read(p []byte) (n int, err error) {
if len(ir.chunks) == 0 {
return 0, errors.New("no data")
}
n = copy(p, ir.chunks[0])
if n != len(ir.chunks[0]) {
panic("this test reader doesn't support partial reads of chunks")
}
ir.chunks = ir.chunks[1:]
return n, nil
}
func (ir *interruptReader) push(p []byte) {
ir.chunks = append(ir.chunks, p)
}
func TestFrontendReceiveInterrupted(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Z', 0, 0, 0, 5})
frontend, err := pgproto3.NewFrontend(server, nil)
if err != nil {
t.Fatal(err)
}
msg, err := frontend.Receive()
if err == nil {
t.Fatal("expected err")
}
if msg != nil {
t.Fatalf("did not expect msg, but %v", msg)
}
server.push([]byte{'I'})
msg, err = frontend.Receive()
if err != nil {
t.Fatal(err)
}
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
t.Fatalf("unexpected msg: %v", msg)
}
}