mirror of https://github.com/jackc/pgx.git
Update pgproto3 to enable pgmock
parent
458dd24a9f
commit
b1489a1eab
11
pgio/read.go
11
pgio/read.go
|
@ -1,6 +1,7 @@
|
|||
package pgio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
|
@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) {
|
|||
buf, n := NextUint64(buf)
|
||||
return buf, int64(n)
|
||||
}
|
||||
|
||||
func NextCString(buf []byte) ([]byte, string, bool) {
|
||||
idx := bytes.IndexByte(buf, 0)
|
||||
if idx < 0 {
|
||||
return buf, "", false
|
||||
}
|
||||
cstring := string(buf[:idx])
|
||||
buf = buf[:idx+1]
|
||||
return buf, cstring, true
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package pgproto3
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
|
@ -20,6 +19,7 @@ type Backend struct {
|
|||
parse Parse
|
||||
passwordMessage PasswordMessage
|
||||
query Query
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
}
|
||||
|
@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
|||
}
|
||||
|
||||
func (b *Backend) Send(msg BackendMessage) error {
|
||||
return errors.New("not implemented")
|
||||
buf, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = b.w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &b.startupMessage, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
|
|
|
@ -2,7 +2,6 @@ package pgproto3
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
|
@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
|||
}
|
||||
|
||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||
return errors.New("not implemented")
|
||||
buf, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = b.w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersionNumber = 196608 // 3.0
|
||||
sslRequestNumber = 80877103
|
||||
)
|
||||
|
||||
type StartupMessage struct {
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}
|
||||
|
||||
func (*StartupMessage) Frontend() {}
|
||||
|
||||
func (dst *StartupMessage) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return fmt.Errorf("startup message too short")
|
||||
}
|
||||
|
||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||
rp := 4
|
||||
|
||||
if dst.ProtocolVersion == sslRequestNumber {
|
||||
return fmt.Errorf("can't handle ssl connection request")
|
||||
}
|
||||
|
||||
if dst.ProtocolVersion != protocolVersionNumber {
|
||||
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion)
|
||||
}
|
||||
|
||||
dst.Parameters = make(map[string]string)
|
||||
for {
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
key := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
value := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
dst.Parameters[key] = value
|
||||
|
||||
if len(src[rp:]) == 1 {
|
||||
if src[rp] != 0 {
|
||||
return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *StartupMessage) MarshalBinary() ([]byte, error) {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
buf.Write(bigEndian.Uint32(src.ProtocolVersion))
|
||||
for k, v := range src.Parameters {
|
||||
buf.WriteString(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len()))
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}{
|
||||
Type: "StartupMessage",
|
||||
ProtocolVersion: src.ProtocolVersion,
|
||||
Parameters: src.Parameters,
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue