From 4c24c635a9c2b9a787f76e7d9ab1151faed71a79 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Mon, 1 May 2017 18:11:55 -0500
Subject: [PATCH] Add pgproto3.Backend

---
 pgproto3/backend.go          |  74 ++++++++++++++++
 pgproto3/bind.go             | 167 +++++++++++++++++++++++++++++++++++
 pgproto3/describe.go         |  60 +++++++++++++
 pgproto3/execute.go          |  60 +++++++++++++
 pgproto3/parse.go            |  82 +++++++++++++++++
 pgproto3/password_message.go |  44 +++++++++
 pgproto3/sync.go             |  29 ++++++
 pgproto3/terminate.go        |  29 ++++++
 8 files changed, 545 insertions(+)
 create mode 100644 pgproto3/backend.go
 create mode 100644 pgproto3/bind.go
 create mode 100644 pgproto3/describe.go
 create mode 100644 pgproto3/execute.go
 create mode 100644 pgproto3/parse.go
 create mode 100644 pgproto3/password_message.go
 create mode 100644 pgproto3/sync.go
 create mode 100644 pgproto3/terminate.go

diff --git a/pgproto3/backend.go b/pgproto3/backend.go
new file mode 100644
index 00000000..c04116a8
--- /dev/null
+++ b/pgproto3/backend.go
@@ -0,0 +1,74 @@
+package pgproto3
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/jackc/pgx/chunkreader"
+)
+
+type Backend struct {
+	cr *chunkreader.ChunkReader
+	w  io.Writer
+
+	// Frontend message flyweights
+	bind            Bind
+	describe        Describe
+	execute         Execute
+	parse           Parse
+	passwordMessage PasswordMessage
+	query           Query
+	sync            Sync
+	terminate       Terminate
+}
+
+func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
+	cr := chunkreader.NewChunkReader(r)
+	return &Backend{cr: cr, w: w}, nil
+}
+
+func (b *Backend) Send(msg BackendMessage) error {
+	return errors.New("not implemented")
+}
+
+func (b *Backend) Receive() (FrontendMessage, error) {
+	header, err := b.cr.Next(5)
+	if err != nil {
+		return nil, err
+	}
+
+	msgType := header[0]
+	bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
+
+	var msg FrontendMessage
+	switch msgType {
+	case 'B':
+		msg = &b.bind
+	case 'D':
+		msg = &b.describe
+	case 'E':
+		msg = &b.execute
+	case 'P':
+		msg = &b.parse
+	case 'p':
+		msg = &b.passwordMessage
+	case 'Q':
+		msg = &b.query
+	case 'S':
+		msg = &b.sync
+	case 'X':
+		msg = &b.terminate
+	default:
+		return nil, fmt.Errorf("unknown message type: %c", msgType)
+	}
+
+	msgBody, err := b.cr.Next(bodyLen)
+	if err != nil {
+		return nil, err
+	}
+
+	err = msg.Decode(msgBody)
+	return msg, err
+}
diff --git a/pgproto3/bind.go b/pgproto3/bind.go
new file mode 100644
index 00000000..6661a775
--- /dev/null
+++ b/pgproto3/bind.go
@@ -0,0 +1,167 @@
+package pgproto3
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+)
+
+type Bind struct {
+	DestinationPortal    string
+	PreparedStatement    string
+	ParameterFormatCodes []int16
+	Parameters           [][]byte
+	ResultFormatCodes    []int16
+}
+
+func (*Bind) Frontend() {}
+
+func (dst *Bind) Decode(src []byte) error {
+	idx := bytes.IndexByte(src, 0)
+	if idx < 0 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	dst.DestinationPortal = string(src[:idx])
+	rp := idx + 1
+
+	idx = bytes.IndexByte(src[rp:], 0)
+	if idx < 0 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	dst.PreparedStatement = string(src[rp : rp+idx])
+	rp += idx + 1
+
+	if len(src[rp:]) < 2 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
+	rp += 2
+
+	dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
+
+	if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	for i := 0; i < parameterFormatCodeCount; i++ {
+		dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
+		rp += 2
+	}
+
+	if len(src[rp:]) < 2 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
+
+	dst.Parameters = make([][]byte, parameterCount)
+
+	for i := 0; i < parameterCount; i++ {
+		if len(src[rp:]) < 4 {
+			return &invalidMessageFormatErr{messageType: "Bind"}
+		}
+
+		msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
+		rp += 4
+
+		// null
+		if msgSize == -1 {
+			continue
+		}
+
+		if len(src[rp:]) < msgSize {
+			return &invalidMessageFormatErr{messageType: "Bind"}
+		}
+
+		dst.Parameters[i] = src[rp : rp+msgSize]
+		rp += msgSize
+	}
+
+	if len(src[rp:]) < 2 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
+	rp += 2
+
+	dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
+	if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
+		return &invalidMessageFormatErr{messageType: "Bind"}
+	}
+	for i := 0; i < resultFormatCodeCount; i++ {
+		dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
+		rp += 2
+	}
+
+	return nil
+}
+
+func (src *Bind) MarshalBinary() ([]byte, error) {
+	var bigEndian BigEndianBuf
+	buf := &bytes.Buffer{}
+
+	buf.WriteByte('B')
+	buf.Write(bigEndian.Uint32(0))
+
+	buf.WriteString(src.DestinationPortal)
+	buf.WriteByte(0)
+	buf.WriteString(src.PreparedStatement)
+	buf.WriteByte(0)
+
+	buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes))))
+
+	for _, fc := range src.ParameterFormatCodes {
+		buf.Write(bigEndian.Int16(fc))
+	}
+
+	buf.Write(bigEndian.Uint16(uint16(len(src.Parameters))))
+
+	for _, p := range src.Parameters {
+		if p == nil {
+			buf.Write(bigEndian.Int32(-1))
+			continue
+		}
+
+		buf.Write(bigEndian.Int32(int32(len(p))))
+		buf.Write(p)
+	}
+
+	buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes))))
+
+	for _, fc := range src.ResultFormatCodes {
+		buf.Write(bigEndian.Int16(fc))
+	}
+
+	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
+
+	return buf.Bytes(), nil
+}
+
+func (src *Bind) MarshalJSON() ([]byte, error) {
+	formattedParameters := make([]map[string]string, len(src.Parameters))
+	for i, p := range src.Parameters {
+		if p == nil {
+			continue
+		}
+
+		if src.ParameterFormatCodes[i] == 0 {
+			formattedParameters[i] = map[string]string{"text": string(p)}
+		} else {
+			formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
+		}
+	}
+
+	return json.Marshal(struct {
+		Type                 string
+		DestinationPortal    string
+		PreparedStatement    string
+		ParameterFormatCodes []int16
+		Parameters           []map[string]string
+		ResultFormatCodes    []int16
+	}{
+		Type:                 "Bind",
+		DestinationPortal:    src.DestinationPortal,
+		PreparedStatement:    src.PreparedStatement,
+		ParameterFormatCodes: src.ParameterFormatCodes,
+		Parameters:           formattedParameters,
+		ResultFormatCodes:    src.ResultFormatCodes,
+	})
+}
diff --git a/pgproto3/describe.go b/pgproto3/describe.go
new file mode 100644
index 00000000..ea55ed9d
--- /dev/null
+++ b/pgproto3/describe.go
@@ -0,0 +1,60 @@
+package pgproto3
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+)
+
+type Describe struct {
+	ObjectType byte // 'S' = prepared statement, 'P' = portal
+	Name       string
+}
+
+func (*Describe) Frontend() {}
+
+func (dst *Describe) Decode(src []byte) error {
+	if len(src) < 2 {
+		return &invalidMessageFormatErr{messageType: "Describe"}
+	}
+
+	dst.ObjectType = src[0]
+	rp := 1
+
+	idx := bytes.IndexByte(src[rp:], 0)
+	if idx != len(src[rp:])-1 {
+		return &invalidMessageFormatErr{messageType: "Describe"}
+	}
+
+	dst.Name = string(src[rp : len(src)-1])
+
+	return nil
+}
+
+func (src *Describe) MarshalBinary() ([]byte, error) {
+	var bigEndian BigEndianBuf
+	buf := &bytes.Buffer{}
+
+	buf.WriteByte('D')
+	buf.Write(bigEndian.Uint32(0))
+
+	buf.WriteByte(src.ObjectType)
+	buf.WriteString(src.Name)
+	buf.WriteByte(0)
+
+	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
+
+	return buf.Bytes(), nil
+}
+
+func (src *Describe) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type       string
+		ObjectType string
+		Name       string
+	}{
+		Type:       "Describe",
+		ObjectType: string(src.ObjectType),
+		Name:       src.Name,
+	})
+}
diff --git a/pgproto3/execute.go b/pgproto3/execute.go
new file mode 100644
index 00000000..4892e7b3
--- /dev/null
+++ b/pgproto3/execute.go
@@ -0,0 +1,60 @@
+package pgproto3
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+)
+
+type Execute struct {
+	Portal  string
+	MaxRows uint32
+}
+
+func (*Execute) Frontend() {}
+
+func (dst *Execute) Decode(src []byte) error {
+	buf := bytes.NewBuffer(src)
+
+	b, err := buf.ReadBytes(0)
+	if err != nil {
+		return err
+	}
+	dst.Portal = string(b[:len(b)-1])
+
+	if buf.Len() < 4 {
+		return &invalidMessageFormatErr{messageType: "Execute"}
+	}
+	dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
+
+	return nil
+}
+
+func (src *Execute) MarshalBinary() ([]byte, error) {
+	var bigEndian BigEndianBuf
+	buf := &bytes.Buffer{}
+
+	buf.WriteByte('E')
+	buf.Write(bigEndian.Uint32(0))
+
+	buf.WriteString(src.Portal)
+	buf.WriteByte(0)
+
+	buf.Write(bigEndian.Uint32(src.MaxRows))
+
+	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
+
+	return buf.Bytes(), nil
+}
+
+func (src *Execute) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type    string
+		Portal  string
+		MaxRows uint32
+	}{
+		Type:    "Execute",
+		Portal:  src.Portal,
+		MaxRows: src.MaxRows,
+	})
+}
diff --git a/pgproto3/parse.go b/pgproto3/parse.go
new file mode 100644
index 00000000..5d17ed11
--- /dev/null
+++ b/pgproto3/parse.go
@@ -0,0 +1,82 @@
+package pgproto3
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+)
+
+type Parse struct {
+	Name          string
+	Query         string
+	ParameterOIDs []uint32
+}
+
+func (*Parse) Frontend() {}
+
+func (dst *Parse) Decode(src []byte) error {
+	buf := bytes.NewBuffer(src)
+
+	b, err := buf.ReadBytes(0)
+	if err != nil {
+		return err
+	}
+	dst.Name = string(b[:len(b)-1])
+
+	b, err = buf.ReadBytes(0)
+	if err != nil {
+		return err
+	}
+	dst.Query = string(b[:len(b)-1])
+
+	if buf.Len() < 2 {
+		return &invalidMessageFormatErr{messageType: "Parse"}
+	}
+	parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
+
+	for i := 0; i < parameterOIDCount; i++ {
+		if buf.Len() < 4 {
+			return &invalidMessageFormatErr{messageType: "Parse"}
+		}
+		dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
+	}
+
+	return nil
+}
+
+func (src *Parse) MarshalBinary() ([]byte, error) {
+	var bigEndian BigEndianBuf
+	buf := &bytes.Buffer{}
+
+	buf.WriteByte('P')
+	buf.Write(bigEndian.Uint32(0))
+
+	buf.WriteString(src.Name)
+	buf.WriteByte(0)
+	buf.WriteString(src.Query)
+	buf.WriteByte(0)
+
+	buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs))))
+
+	for _, v := range src.ParameterOIDs {
+		buf.Write(bigEndian.Uint32(v))
+	}
+
+	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
+
+	return buf.Bytes(), nil
+}
+
+func (src *Parse) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type          string
+		Name          string
+		Query         string
+		ParameterOIDs []uint32
+	}{
+		Type:          "Parse",
+		Name:          src.Name,
+		Query:         src.Query,
+		ParameterOIDs: src.ParameterOIDs,
+	})
+}
diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go
new file mode 100644
index 00000000..69df6362
--- /dev/null
+++ b/pgproto3/password_message.go
@@ -0,0 +1,44 @@
+package pgproto3
+
+import (
+	"bytes"
+	"encoding/json"
+)
+
+type PasswordMessage struct {
+	Password string
+}
+
+func (*PasswordMessage) Frontend() {}
+
+func (dst *PasswordMessage) Decode(src []byte) error {
+	buf := bytes.NewBuffer(src)
+
+	b, err := buf.ReadBytes(0)
+	if err != nil {
+		return err
+	}
+	dst.Password = string(b[:len(b)-1])
+
+	return nil
+}
+
+func (src *PasswordMessage) MarshalBinary() ([]byte, error) {
+	var bigEndian BigEndianBuf
+	buf := &bytes.Buffer{}
+	buf.WriteByte('p')
+	buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1)))
+	buf.WriteString(src.Password)
+	buf.WriteByte(0)
+	return buf.Bytes(), nil
+}
+
+func (src *PasswordMessage) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type     string
+		Password string
+	}{
+		Type:     "PasswordMessage",
+		Password: src.Password,
+	})
+}
diff --git a/pgproto3/sync.go b/pgproto3/sync.go
new file mode 100644
index 00000000..da3fa727
--- /dev/null
+++ b/pgproto3/sync.go
@@ -0,0 +1,29 @@
+package pgproto3
+
+import (
+	"encoding/json"
+)
+
+type Sync struct{}
+
+func (*Sync) Frontend() {}
+
+func (dst *Sync) Decode(src []byte) error {
+	if len(src) != 0 {
+		return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
+	}
+
+	return nil
+}
+
+func (src *Sync) MarshalBinary() ([]byte, error) {
+	return []byte{'S', 0, 0, 0, 4}, nil
+}
+
+func (src *Sync) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type string
+	}{
+		Type: "Sync",
+	})
+}
diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go
new file mode 100644
index 00000000..77977f20
--- /dev/null
+++ b/pgproto3/terminate.go
@@ -0,0 +1,29 @@
+package pgproto3
+
+import (
+	"encoding/json"
+)
+
+type Terminate struct{}
+
+func (*Terminate) Frontend() {}
+
+func (dst *Terminate) Decode(src []byte) error {
+	if len(src) != 0 {
+		return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
+	}
+
+	return nil
+}
+
+func (src *Terminate) MarshalBinary() ([]byte, error) {
+	return []byte{'X', 0, 0, 0, 4}, nil
+}
+
+func (src *Terminate) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type string
+	}{
+		Type: "Terminate",
+	})
+}