mirror of https://github.com/jackc/pgx.git
Fix reading interrupted messages
When an message is received and a timeout occurs after reading the header but before reading the entire body the connection state could be corrupted due to the header being consumed. The next read would consider the body of the previous message as the header for the next. fixes #348pull/359/merge
parent
1ed4024c70
commit
cbb3fa5ecc
|
@ -24,6 +24,10 @@ type Backend struct {
|
||||||
startupMessage StartupMessage
|
startupMessage StartupMessage
|
||||||
sync Sync
|
sync Sync
|
||||||
terminate Terminate
|
terminate Terminate
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
||||||
|
@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
header, err := b.cr.Next(5)
|
if !b.partialMsg {
|
||||||
if err != nil {
|
header, err := b.cr.Next(5)
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.msgType = header[0]
|
||||||
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
b.partialMsg = true
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := header[0]
|
|
||||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
|
||||||
|
|
||||||
var msg FrontendMessage
|
var msg FrontendMessage
|
||||||
switch msgType {
|
switch b.msgType {
|
||||||
case 'B':
|
case 'B':
|
||||||
msg = &b.bind
|
msg = &b.bind
|
||||||
case 'C':
|
case 'C':
|
||||||
|
@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
case 'X':
|
case 'X':
|
||||||
msg = &b.terminate
|
msg = &b.terminate
|
||||||
default:
|
default:
|
||||||
return nil, errors.Errorf("unknown message type: %c", msgType)
|
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgBody, err := b.cr.Next(bodyLen)
|
msgBody, err := b.cr.Next(b.bodyLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.partialMsg = false
|
||||||
|
|
||||||
err = msg.Decode(msgBody)
|
err = msg.Decode(msgBody)
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgproto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||||
|
|
||||||
|
backend, err := pgproto3.NewBackend(server, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.push([]byte{'I', 0})
|
||||||
|
|
||||||
|
msg, err = backend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
||||||
|
t.Fatalf("unexpected msg: %v", msg)
|
||||||
|
}
|
||||||
|
}
|
|
@ -34,6 +34,10 @@ type Frontend struct {
|
||||||
parseComplete ParseComplete
|
parseComplete ParseComplete
|
||||||
readyForQuery ReadyForQuery
|
readyForQuery ReadyForQuery
|
||||||
rowDescription RowDescription
|
rowDescription RowDescription
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
||||||
|
@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||||
header, err := b.cr.Next(5)
|
if !b.partialMsg {
|
||||||
if err != nil {
|
header, err := b.cr.Next(5)
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.msgType = header[0]
|
||||||
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
b.partialMsg = true
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := header[0]
|
|
||||||
bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
|
|
||||||
|
|
||||||
var msg BackendMessage
|
var msg BackendMessage
|
||||||
switch msgType {
|
switch b.msgType {
|
||||||
case '1':
|
case '1':
|
||||||
msg = &b.parseComplete
|
msg = &b.parseComplete
|
||||||
case '2':
|
case '2':
|
||||||
|
@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) {
|
||||||
case 'Z':
|
case 'Z':
|
||||||
msg = &b.readyForQuery
|
msg = &b.readyForQuery
|
||||||
default:
|
default:
|
||||||
return nil, errors.Errorf("unknown message type: %c", msgType)
|
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
||||||
}
|
}
|
||||||
|
|
||||||
msgBody, err := b.cr.Next(bodyLen)
|
msgBody, err := b.cr.Next(b.bodyLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.partialMsg = false
|
||||||
|
|
||||||
err = msg.Decode(msgBody)
|
err = msg.Decode(msgBody)
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgproto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type interruptReader struct {
|
||||||
|
chunks [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir *interruptReader) Read(p []byte) (n int, err error) {
|
||||||
|
if len(ir.chunks) == 0 {
|
||||||
|
return 0, errors.New("no data")
|
||||||
|
}
|
||||||
|
|
||||||
|
n = copy(p, ir.chunks[0])
|
||||||
|
if n != len(ir.chunks[0]) {
|
||||||
|
panic("this test reader doesn't support partial reads of chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
ir.chunks = ir.chunks[1:]
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir *interruptReader) push(p []byte) {
|
||||||
|
ir.chunks = append(ir.chunks, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendReceiveInterrupted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Z', 0, 0, 0, 5})
|
||||||
|
|
||||||
|
frontend, err := pgproto3.NewFrontend(server, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := frontend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.push([]byte{'I'})
|
||||||
|
|
||||||
|
msg, err = frontend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
|
||||||
|
t.Fatalf("unexpected msg: %v", msg)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue