From cbb3fa5ecc6ad01b99c628a13b078862d7fb627c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Dec 2017 13:45:22 -0600 Subject: [PATCH] Fix reading interrupted messages When an message is received and a timeout occurs after reading the header but before reading the entire body the connection state could be corrupted due to the header being consumed. The next read would consider the body of the previous message as the header for the next. fixes #348 --- pgproto3/backend.go | 27 +++++++++++------ pgproto3/backend_test.go | 37 +++++++++++++++++++++++ pgproto3/frontend.go | 27 +++++++++++------ pgproto3/frontend_test.go | 62 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 18 deletions(-) create mode 100644 pgproto3/backend_test.go create mode 100644 pgproto3/frontend_test.go diff --git a/pgproto3/backend.go b/pgproto3/backend.go index 9a7ef342..8f3c3478 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -24,6 +24,10 @@ type Backend struct { startupMessage StartupMessage sync Sync terminate Terminate + + bodyLen int + msgType byte + partialMsg bool } func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { @@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { } func (b *Backend) Receive() (FrontendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg FrontendMessage - switch msgType { + switch b.msgType { case 'B': msg = &b.bind case 'C': @@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go new file mode 100644 index 00000000..02a5e9ca --- /dev/null +++ b/pgproto3/backend_test.go @@ -0,0 +1,37 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/pgproto3" +) + +func TestBackendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend, err := pgproto3.NewBackend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := backend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I', 0}) + + msg, err = backend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { + t.Fatalf("unexpected msg: %v", msg) + } +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index c8ab5f15..d803d362 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -34,6 +34,10 @@ type Frontend struct { parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription + + bodyLen int + msgType byte + partialMsg bool } func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { @@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error { } func (b *Frontend) Receive() (BackendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg BackendMessage - switch msgType { + switch b.msgType { case '1': msg = &b.parseComplete case '2': @@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go new file mode 100644 index 00000000..7d6652c1 --- /dev/null +++ b/pgproto3/frontend_test.go @@ -0,0 +1,62 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgproto3" +) + +type interruptReader struct { + chunks [][]byte +} + +func (ir *interruptReader) Read(p []byte) (n int, err error) { + if len(ir.chunks) == 0 { + return 0, errors.New("no data") + } + + n = copy(p, ir.chunks[0]) + if n != len(ir.chunks[0]) { + panic("this test reader doesn't support partial reads of chunks") + } + + ir.chunks = ir.chunks[1:] + + return n, nil +} + +func (ir *interruptReader) push(p []byte) { + ir.chunks = append(ir.chunks, p) +} + +func TestFrontendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend, err := pgproto3.NewFrontend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I'}) + + msg, err = frontend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { + t.Fatalf("unexpected msg: %v", msg) + } +}