feat(pgproto3/backend): add a SetMaxBodyLen to limit the max body length for the receive

pull/1856/head
jeremy.spriet 2023-12-19 18:19:43 +01:00 committed by Jack Christensen
parent 9ab9e3c40b
commit 603c8c1e90
3 changed files with 40 additions and 0 deletions

View File

@ -38,6 +38,7 @@ type Backend struct {
terminate Terminate
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte
partialMsg bool
authType uint32
@ -158,6 +159,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.partialMsg = true
}
@ -260,3 +264,12 @@ func (b *Backend) SetAuthType(authType uint32) error {
return nil
}
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
// causing memory exhaustion.
// The default value is 0.
// If maxBodyLen is 0, then no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}

View File

@ -120,3 +120,21 @@ func TestStartupMessage(t *testing.T) {
}
})
}
func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
t.Parallel()
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 10, 10})
backend := pgproto3.NewBackend(server, nil)
// Set max body len to 5
backend.SetMaxBodyLen(5)
// Receive regular msg
msg, err := backend.Receive()
assert.Nil(t, msg)
var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
assert.ErrorAs(t, err, &invalidBodyLenErr)
}

View File

@ -70,6 +70,15 @@ func (e *writeError) Unwrap() error {
return e.err
}
type ExceededMaxBodyLenErr struct {
maxExpectedBodyLen int
actualBodyLen int
}
func (e *ExceededMaxBodyLenErr) Error() string {
return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.maxExpectedBodyLen, e.actualBodyLen)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {