diff --git a/pgproto3/backend.go b/pgproto3/backend.go new file mode 100644 index 00000000..c04116a8 --- /dev/null +++ b/pgproto3/backend.go @@ -0,0 +1,74 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Backend struct { + cr *chunkreader.ChunkReader + w io.Writer + + // Frontend message flyweights + bind Bind + describe Describe + execute Execute + parse Parse + passwordMessage PasswordMessage + query Query + sync Sync + terminate Terminate +} + +func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { + cr := chunkreader.NewChunkReader(r) + return &Backend{cr: cr, w: w}, nil +} + +func (b *Backend) Send(msg BackendMessage) error { + return errors.New("not implemented") +} + +func (b *Backend) Receive() (FrontendMessage, error) { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + var msg FrontendMessage + switch msgType { + case 'B': + msg = &b.bind + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'P': + msg = &b.parse + case 'p': + msg = &b.passwordMessage + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, fmt.Errorf("unknown message type: %c", msgType) + } + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + err = msg.Decode(msgBody) + return msg, err +} diff --git a/pgproto3/bind.go b/pgproto3/bind.go new file mode 100644 index 00000000..6661a775 --- /dev/null +++ b/pgproto3/bind.go @@ -0,0 +1,167 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +func (*Bind) Frontend() {} + +func (dst *Bind) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +func (src *Bind) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('B') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.DestinationPortal) + buf.WriteByte(0) + buf.WriteString(src.PreparedStatement) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes)))) + + for _, fc := range src.ParameterFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.Parameters)))) + + for _, p := range src.Parameters { + if p == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(p)))) + buf.Write(p) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes)))) + + for _, fc := range src.ResultFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + if src.ParameterFormatCodes[i] == 0 { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} diff --git a/pgproto3/describe.go b/pgproto3/describe.go new file mode 100644 index 00000000..ea55ed9d --- /dev/null +++ b/pgproto3/describe.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Describe) Frontend() {} + +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Describe) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteByte(src.ObjectType) + buf.WriteString(src.Name) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go new file mode 100644 index 00000000..4892e7b3 --- /dev/null +++ b/pgproto3/execute.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +func (*Execute) Frontend() {} + +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +func (src *Execute) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('E') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Portal) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(src.MaxRows)) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/pgproto3/parse.go b/pgproto3/parse.go new file mode 100644 index 00000000..5d17ed11 --- /dev/null +++ b/pgproto3/parse.go @@ -0,0 +1,82 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +func (*Parse) Frontend() {} + +func (dst *Parse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +func (src *Parse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('P') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Query) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, v := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(v)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go new file mode 100644 index 00000000..69df6362 --- /dev/null +++ b/pgproto3/password_message.go @@ -0,0 +1,44 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type PasswordMessage struct { + Password string +} + +func (*PasswordMessage) Frontend() {} + +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +func (src *PasswordMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('p') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1))) + buf.WriteString(src.Password) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/pgproto3/sync.go b/pgproto3/sync.go new file mode 100644 index 00000000..da3fa727 --- /dev/null +++ b/pgproto3/sync.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +func (*Sync) Frontend() {} + +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Sync) MarshalBinary() ([]byte, error) { + return []byte{'S', 0, 0, 0, 4}, nil +} + +func (src *Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go new file mode 100644 index 00000000..77977f20 --- /dev/null +++ b/pgproto3/terminate.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +func (*Terminate) Frontend() {} + +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Terminate) MarshalBinary() ([]byte, error) { + return []byte{'X', 0, 0, 0, 4}, nil +} + +func (src *Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +}