mirror of https://github.com/jackc/pgx.git
feat(pgproto3/backend): add a SetMaxBodyLen to limit the max body length for the receive
parent
9ab9e3c40b
commit
603c8c1e90
|
@ -38,6 +38,7 @@ type Backend struct {
|
|||
terminate Terminate
|
||||
|
||||
bodyLen int
|
||||
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
authType uint32
|
||||
|
@ -158,6 +159,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
|||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
|
||||
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
|
||||
}
|
||||
b.partialMsg = true
|
||||
}
|
||||
|
||||
|
@ -260,3 +264,12 @@ func (b *Backend) SetAuthType(authType uint32) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
|
||||
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
|
||||
// causing memory exhaustion.
|
||||
// The default value is 0.
|
||||
// If maxBodyLen is 0, then no maximum is enforced.
|
||||
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
|
||||
b.maxBodyLen = maxBodyLen
|
||||
}
|
||||
|
|
|
@ -120,3 +120,21 @@ func TestStartupMessage(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Q', 0, 0, 10, 10})
|
||||
|
||||
backend := pgproto3.NewBackend(server, nil)
|
||||
|
||||
// Set max body len to 5
|
||||
backend.SetMaxBodyLen(5)
|
||||
|
||||
// Receive regular msg
|
||||
msg, err := backend.Receive()
|
||||
assert.Nil(t, msg)
|
||||
var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
|
||||
assert.ErrorAs(t, err, &invalidBodyLenErr)
|
||||
}
|
||||
|
|
|
@ -70,6 +70,15 @@ func (e *writeError) Unwrap() error {
|
|||
return e.err
|
||||
}
|
||||
|
||||
type ExceededMaxBodyLenErr struct {
|
||||
maxExpectedBodyLen int
|
||||
actualBodyLen int
|
||||
}
|
||||
|
||||
func (e *ExceededMaxBodyLenErr) Error() string {
|
||||
return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.maxExpectedBodyLen, e.actualBodyLen)
|
||||
}
|
||||
|
||||
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
||||
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||
if v == nil {
|
||||
|
|
Loading…
Reference in New Issue