mirror of https://github.com/jackc/pgx.git
Support SSLRequest and CancelRequest
parent
76538434cf
commit
1ba5dcbe01
38
backend.go
38
backend.go
|
@ -2,6 +2,7 @@ package pgproto3
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -14,6 +15,7 @@ type Backend struct {
|
|||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
cancelRequest CancelRequest
|
||||
_close Close
|
||||
copyFail CopyFail
|
||||
describe Describe
|
||||
|
@ -22,6 +24,7 @@ type Backend struct {
|
|||
parse Parse
|
||||
passwordMessage PasswordMessage
|
||||
query Query
|
||||
sslRequest SSLRequest
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
|
@ -42,9 +45,10 @@ func (b *Backend) Send(msg BackendMessage) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method
|
||||
// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte.
|
||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||
// will return either a StartupMessage, SSLRequest, or CancelRequest.
|
||||
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -56,12 +60,30 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code := binary.BigEndian.Uint32(buf)
|
||||
|
||||
return &b.startupMessage, nil
|
||||
switch code {
|
||||
case ProtocolVersionNumber:
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.startupMessage, nil
|
||||
case sslRequestNumber:
|
||||
err = b.sslRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.sslRequest, nil
|
||||
case cancelRequestCode:
|
||||
err = b.cancelRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.cancelRequest, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// Receive receives a message from the frontend.
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const cancelRequestCode = 80877102
|
||||
|
||||
type CancelRequest struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CancelRequest) Frontend() {}
|
||||
|
||||
func (dst *CancelRequest) Decode(src []byte) error {
|
||||
if len(src) != 12 {
|
||||
return errors.Errorf("bad cancel request size")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != cancelRequestCode {
|
||||
return errors.Errorf("bad cancel request code")
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "CancelRequest",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const sslRequestNumber = 80877103
|
||||
|
||||
type SSLRequest struct {
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*SSLRequest) Frontend() {}
|
||||
|
||||
func (dst *SSLRequest) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.Errorf("ssl request too short")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != sslRequestNumber {
|
||||
return errors.Errorf("bad ssl request code")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src SSLRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}{
|
||||
Type: "SSLRequest",
|
||||
})
|
||||
}
|
|
@ -9,10 +9,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolVersionNumber = 196608 // 3.0
|
||||
sslRequestNumber = 80877103
|
||||
)
|
||||
const ProtocolVersionNumber = 196608 // 3.0
|
||||
|
||||
type StartupMessage struct {
|
||||
ProtocolVersion uint32
|
||||
|
@ -32,10 +29,6 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
|||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||
rp := 4
|
||||
|
||||
if dst.ProtocolVersion == sslRequestNumber {
|
||||
return errors.Errorf("can't handle ssl connection request")
|
||||
}
|
||||
|
||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue