diff --git a/pgproto3/backend.go b/pgproto3/backend.go index 9a7ef342..8f3c3478 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -24,6 +24,10 @@ type Backend struct { startupMessage StartupMessage sync Sync terminate Terminate + + bodyLen int + msgType byte + partialMsg bool } func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { @@ -57,16 +61,19 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { } func (b *Backend) Receive() (FrontendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg FrontendMessage - switch msgType { + switch b.msgType { case 'B': msg = &b.bind case 'C': @@ -88,14 +95,16 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go new file mode 100644 index 00000000..02a5e9ca --- /dev/null +++ b/pgproto3/backend_test.go @@ -0,0 +1,37 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/pgproto3" +) + +func TestBackendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend, err := pgproto3.NewBackend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := backend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I', 0}) + + msg, err = backend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { + t.Fatalf("unexpected msg: %v", msg) + } +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index c8ab5f15..d803d362 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -34,6 +34,10 @@ type Frontend struct { parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription + + bodyLen int + msgType byte + partialMsg bool } func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { @@ -47,16 +51,19 @@ func (b *Frontend) Send(msg FrontendMessage) error { } func (b *Frontend) Receive() (BackendMessage, error) { - header, err := b.cr.Next(5) - if err != nil { - return nil, err + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true } - msgType := header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - var msg BackendMessage - switch msgType { + switch b.msgType { case '1': msg = &b.parseComplete case '2': @@ -100,14 +107,16 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, errors.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", b.msgType) } - msgBody, err := b.cr.Next(bodyLen) + msgBody, err := b.cr.Next(b.bodyLen) if err != nil { return nil, err } + b.partialMsg = false + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go new file mode 100644 index 00000000..7d6652c1 --- /dev/null +++ b/pgproto3/frontend_test.go @@ -0,0 +1,62 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgproto3" +) + +type interruptReader struct { + chunks [][]byte +} + +func (ir *interruptReader) Read(p []byte) (n int, err error) { + if len(ir.chunks) == 0 { + return 0, errors.New("no data") + } + + n = copy(p, ir.chunks[0]) + if n != len(ir.chunks[0]) { + panic("this test reader doesn't support partial reads of chunks") + } + + ir.chunks = ir.chunks[1:] + + return n, nil +} + +func (ir *interruptReader) push(p []byte) { + ir.chunks = append(ir.chunks, p) +} + +func TestFrontendReceiveInterrupted(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Z', 0, 0, 0, 5}) + + frontend, err := pgproto3.NewFrontend(server, nil) + if err != nil { + t.Fatal(err) + } + + msg, err := frontend.Receive() + if err == nil { + t.Fatal("expected err") + } + if msg != nil { + t.Fatalf("did not expect msg, but %v", msg) + } + + server.push([]byte{'I'}) + + msg, err = frontend.Receive() + if err != nil { + t.Fatal(err) + } + if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { + t.Fatalf("unexpected msg: %v", msg) + } +}