diff --git a/pgio/read.go b/pgio/read.go index 7ddad508..033bada4 100644 --- a/pgio/read.go +++ b/pgio/read.go @@ -1,6 +1,7 @@ package pgio import ( + "bytes" "encoding/binary" ) @@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) { buf, n := NextUint64(buf) return buf, int64(n) } + +func NextCString(buf []byte) ([]byte, string, bool) { + idx := bytes.IndexByte(buf, 0) + if idx < 0 { + return buf, "", false + } + cstring := string(buf[:idx]) + buf = buf[:idx+1] + return buf, cstring, true +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go index c04116a8..bd477315 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -20,6 +19,7 @@ type Backend struct { parse Parse passwordMessage PasswordMessage query Query + startupMessage StartupMessage sync Sync terminate Terminate } @@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { } func (b *Backend) Send(msg BackendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err +} + +func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, err + } + + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + + return &b.startupMessage, nil } func (b *Backend) Receive() (FrontendMessage, error) { diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 50835836..27a9890a 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { } func (b *Frontend) Send(msg FrontendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err } func (b *Frontend) Receive() (BackendMessage, error) { diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go new file mode 100644 index 00000000..ebb804fe --- /dev/null +++ b/pgproto3/startup_message.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" +) + +const ( + protocolVersionNumber = 196608 // 3.0 + sslRequestNumber = 80877103 +) + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +func (*StartupMessage) Frontend() {} + +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return fmt.Errorf("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion == sslRequestNumber { + return fmt.Errorf("can't handle ssl connection request") + } + + if dst.ProtocolVersion != protocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +func (src *StartupMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.ProtocolVersion)) + for k, v := range src.Parameters { + buf.WriteString(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len())) + + return buf.Bytes(), nil +} + +func (src *StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +}