Update pgproto3 to enable pgmock

batch-wip
Jack Christensen 2017-05-06 08:48:40 -05:00
parent 458dd24a9f
commit b1489a1eab
4 changed files with 141 additions and 4 deletions

View File

@ -1,6 +1,7 @@
package pgio
import (
"bytes"
"encoding/binary"
)
@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) {
buf, n := NextUint64(buf)
return buf, int64(n)
}
func NextCString(buf []byte) ([]byte, string, bool) {
idx := bytes.IndexByte(buf, 0)
if idx < 0 {
return buf, "", false
}
cstring := string(buf[:idx])
buf = buf[:idx+1]
return buf, cstring, true
}

View File

@ -2,7 +2,6 @@ package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
@ -20,6 +19,7 @@ type Backend struct {
parse Parse
passwordMessage PasswordMessage
query Query
startupMessage StartupMessage
sync Sync
terminate Terminate
}
@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
}
func (b *Backend) Send(msg BackendMessage) error {
return errors.New("not implemented")
buf, err := msg.MarshalBinary()
if err != nil {
return nil
}
_, err = b.w.Write(buf)
return err
}
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
buf, err := b.cr.Next(4)
if err != nil {
return nil, err
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, err
}
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
}
return &b.startupMessage, nil
}
func (b *Backend) Receive() (FrontendMessage, error) {

View File

@ -2,7 +2,6 @@ package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
}
func (b *Frontend) Send(msg FrontendMessage) error {
return errors.New("not implemented")
buf, err := msg.MarshalBinary()
if err != nil {
return nil
}
_, err = b.w.Write(buf)
return err
}
func (b *Frontend) Receive() (BackendMessage, error) {

View File

@ -0,0 +1,95 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
)
const (
protocolVersionNumber = 196608 // 3.0
sslRequestNumber = 80877103
)
type StartupMessage struct {
ProtocolVersion uint32
Parameters map[string]string
}
func (*StartupMessage) Frontend() {}
func (dst *StartupMessage) Decode(src []byte) error {
if len(src) < 4 {
return fmt.Errorf("startup message too short")
}
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
rp := 4
if dst.ProtocolVersion == sslRequestNumber {
return fmt.Errorf("can't handle ssl connection request")
}
if dst.ProtocolVersion != protocolVersionNumber {
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion)
}
dst.Parameters = make(map[string]string)
for {
idx := bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
}
key := string(src[rp : rp+idx])
rp += idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "StartupMesage"}
}
value := string(src[rp : rp+idx])
rp += idx + 1
dst.Parameters[key] = value
if len(src[rp:]) == 1 {
if src[rp] != 0 {
return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
}
break
}
}
return nil
}
func (src *StartupMessage) MarshalBinary() ([]byte, error) {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.Write(bigEndian.Uint32(0))
buf.Write(bigEndian.Uint32(src.ProtocolVersion))
for k, v := range src.Parameters {
buf.WriteString(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len()))
return buf.Bytes(), nil
}
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProtocolVersion uint32
Parameters map[string]string
}{
Type: "StartupMessage",
ProtocolVersion: src.ProtocolVersion,
Parameters: src.Parameters,
})
}