Support SSLRequest and CancelRequest

query-exec-mode
Jack Christensen 2019-08-31 11:48:01 -05:00
parent 76538434cf
commit 1ba5dcbe01
4 changed files with 138 additions and 16 deletions

View File

@ -2,6 +2,7 @@ package pgproto3
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -14,6 +15,7 @@ type Backend struct {
// Frontend message flyweights // Frontend message flyweights
bind Bind bind Bind
cancelRequest CancelRequest
_close Close _close Close
copyFail CopyFail copyFail CopyFail
describe Describe describe Describe
@ -22,6 +24,7 @@ type Backend struct {
parse Parse parse Parse
passwordMessage PasswordMessage passwordMessage PasswordMessage
query Query query Query
sslRequest SSLRequest
startupMessage StartupMessage startupMessage StartupMessage
sync Sync sync Sync
terminate Terminate terminate Terminate
@ -42,9 +45,10 @@ func (b *Backend) Send(msg BackendMessage) error {
return err return err
} }
// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte. // because the initial connection message is "special" and does not include the message type as the first byte. This
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { // will return either a StartupMessage, SSLRequest, or CancelRequest.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
buf, err := b.cr.Next(4) buf, err := b.cr.Next(4)
if err != nil { if err != nil {
return nil, err return nil, err
@ -56,12 +60,30 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
return nil, err return nil, err
} }
code := binary.BigEndian.Uint32(buf)
switch code {
case ProtocolVersionNumber:
err = b.startupMessage.Decode(buf) err = b.startupMessage.Decode(buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &b.startupMessage, nil return &b.startupMessage, nil
case sslRequestNumber:
err = b.sslRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.sslRequest, nil
case cancelRequestCode:
err = b.cancelRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.cancelRequest, nil
default:
return nil, fmt.Errorf("unknown startup message code: %d", code)
}
} }
// Receive receives a message from the frontend. // Receive receives a message from the frontend.

58
cancel_request.go Normal file
View File

@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
"github.com/pkg/errors"
)
const cancelRequestCode = 80877102
type CancelRequest struct {
ProcessID uint32
SecretKey uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CancelRequest) Frontend() {}
func (dst *CancelRequest) Decode(src []byte) error {
if len(src) != 12 {
return errors.Errorf("bad cancel request size")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != cancelRequestCode {
return errors.Errorf("bad cancel request code")
}
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CancelRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "CancelRequest",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

49
ssl_request.go Normal file
View File

@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
"github.com/pkg/errors"
)
const sslRequestNumber = 80877103
type SSLRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*SSLRequest) Frontend() {}
func (dst *SSLRequest) Decode(src []byte) error {
if len(src) < 4 {
return errors.Errorf("ssl request too short")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != sslRequestNumber {
return errors.Errorf("bad ssl request code")
}
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *SSLRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, sslRequestNumber)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src SSLRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProtocolVersion uint32
Parameters map[string]string
}{
Type: "SSLRequest",
})
}

View File

@ -9,10 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const ( const ProtocolVersionNumber = 196608 // 3.0
ProtocolVersionNumber = 196608 // 3.0
sslRequestNumber = 80877103
)
type StartupMessage struct { type StartupMessage struct {
ProtocolVersion uint32 ProtocolVersion uint32
@ -32,10 +29,6 @@ func (dst *StartupMessage) Decode(src []byte) error {
dst.ProtocolVersion = binary.BigEndian.Uint32(src) dst.ProtocolVersion = binary.BigEndian.Uint32(src)
rp := 4 rp := 4
if dst.ProtocolVersion == sslRequestNumber {
return errors.Errorf("can't handle ssl connection request")
}
if dst.ProtocolVersion != ProtocolVersionNumber { if dst.ProtocolVersion != ProtocolVersionNumber {
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
} }