mirror of https://github.com/jackc/pgx.git
Support SSLRequest and CancelRequest
parent
76538434cf
commit
1ba5dcbe01
30
backend.go
30
backend.go
|
@ -2,6 +2,7 @@ package pgproto3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -14,6 +15,7 @@ type Backend struct {
|
||||||
|
|
||||||
// Frontend message flyweights
|
// Frontend message flyweights
|
||||||
bind Bind
|
bind Bind
|
||||||
|
cancelRequest CancelRequest
|
||||||
_close Close
|
_close Close
|
||||||
copyFail CopyFail
|
copyFail CopyFail
|
||||||
describe Describe
|
describe Describe
|
||||||
|
@ -22,6 +24,7 @@ type Backend struct {
|
||||||
parse Parse
|
parse Parse
|
||||||
passwordMessage PasswordMessage
|
passwordMessage PasswordMessage
|
||||||
query Query
|
query Query
|
||||||
|
sslRequest SSLRequest
|
||||||
startupMessage StartupMessage
|
startupMessage StartupMessage
|
||||||
sync Sync
|
sync Sync
|
||||||
terminate Terminate
|
terminate Terminate
|
||||||
|
@ -42,9 +45,10 @@ func (b *Backend) Send(msg BackendMessage) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method
|
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||||
// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte.
|
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
// will return either a StartupMessage, SSLRequest, or CancelRequest.
|
||||||
|
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||||
buf, err := b.cr.Next(4)
|
buf, err := b.cr.Next(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -56,12 +60,30 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
code := binary.BigEndian.Uint32(buf)
|
||||||
|
|
||||||
|
switch code {
|
||||||
|
case ProtocolVersionNumber:
|
||||||
err = b.startupMessage.Decode(buf)
|
err = b.startupMessage.Decode(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &b.startupMessage, nil
|
return &b.startupMessage, nil
|
||||||
|
case sslRequestNumber:
|
||||||
|
err = b.sslRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.sslRequest, nil
|
||||||
|
case cancelRequestCode:
|
||||||
|
err = b.cancelRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.cancelRequest, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive receives a message from the frontend.
|
// Receive receives a message from the frontend.
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const cancelRequestCode = 80877102
|
||||||
|
|
||||||
|
type CancelRequest struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CancelRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *CancelRequest) Decode(src []byte) error {
|
||||||
|
if len(src) != 12 {
|
||||||
|
return errors.Errorf("bad cancel request size")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != cancelRequestCode {
|
||||||
|
return errors.Errorf("bad cancel request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 16)
|
||||||
|
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "CancelRequest",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sslRequestNumber = 80877103
|
||||||
|
|
||||||
|
type SSLRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SSLRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *SSLRequest) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.Errorf("ssl request too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != sslRequestNumber {
|
||||||
|
return errors.Errorf("bad ssl request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SSLRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "SSLRequest",
|
||||||
|
})
|
||||||
|
}
|
|
@ -9,10 +9,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const ProtocolVersionNumber = 196608 // 3.0
|
||||||
ProtocolVersionNumber = 196608 // 3.0
|
|
||||||
sslRequestNumber = 80877103
|
|
||||||
)
|
|
||||||
|
|
||||||
type StartupMessage struct {
|
type StartupMessage struct {
|
||||||
ProtocolVersion uint32
|
ProtocolVersion uint32
|
||||||
|
@ -32,10 +29,6 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
||||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||||
rp := 4
|
rp := 4
|
||||||
|
|
||||||
if dst.ProtocolVersion == sslRequestNumber {
|
|
||||||
return errors.Errorf("can't handle ssl connection request")
|
|
||||||
}
|
|
||||||
|
|
||||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||||
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue