feat: Add Max Func to Limiter Middleware (#3070)

* feat: add max calculator to limiter middleware

* docs: update docs including the new parameter

* refactor: add new line before go code in docs

* fix: use crypto/rand instead of math/rand on tests

* test: add new test with zero set as limit

* fix: repeated tests failing when generating random limits

* fix: wrong type of MaxCalculator in docs

* feat: include max calculator in limiter_sliding

* refactor: rename MaxCalculator to MaxFunc

* docs: update docs with MaxFunc parameter

* tests: rename tests and add test for limiter sliding
pull/3076/head
Lucas Lemos 2024-07-23 18:00:37 -03:00 committed by Juan Calderon-Perez
parent d7d91598c5
commit c90219026f
5 changed files with 174 additions and 10 deletions

View File

@ -43,6 +43,9 @@ app.Use(limiter.New(limiter.Config{
return c.IP() == "127.0.0.1"
},
Max: 20,
MaxFunc: func(c fiber.Ctx) int {
return 20
},
Expiration: 30 * time.Second,
KeyGenerator: func(c fiber.Ctx) string {
return c.Get("x-forwarded-for")
@ -75,12 +78,28 @@ weightOfPreviousWindow = previous window's amount request * (whenNewWindow / Exp
rate = weightOfPreviousWindow + current window's amount request.
```
## Dynamic limit
You can also calculate the limit dynamically using the MaxFunc parameter. It's a function that receives the request's context as a parameter and allow you to calculate a different limit for each request separately.
Example:
```go
app.Use(limiter.New(limiter.Config{
MaxFunc: func(c fiber.Ctx) int {
return getUserLimit(ctx.Param("id"))
},
Expiration: 30 * time.Second,
}))
```
## Config
| Property | Type | Description | Default |
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 |
| MaxFunc | `func(fiber.Ctx) int` | A function to calculate the max number of recent connections during `Expiration` seconds before sending a 429 response. | A function which returns the cfg.Max |
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default |
| Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute |
| LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response |
@ -101,6 +120,9 @@ A custom store can be used if it implements the `Storage` interface - more detai
```go
var ConfigDefault = Config{
Max: 5,
MaxFunc: func(c fiber.Ctx) int {
return 5
},
Expiration: 1 * time.Minute,
KeyGenerator: func(c fiber.Ctx) string {
return c.IP()

View File

@ -22,6 +22,13 @@ type Config struct {
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// A function to dynamically calculate the max requests supported by the rate limiter middleware
//
// Default: func(c fiber.Ctx) int {
// return c.Max
// }
MaxFunc func(c fiber.Ctx) int
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c fiber.Ctx) string {
@ -101,5 +108,10 @@ func configDefault(config ...Config) Config {
if cfg.LimiterMiddleware == nil {
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
}
if cfg.MaxFunc == nil {
cfg.MaxFunc = func(_ fiber.Ctx) int {
return cfg.Max
}
}
return cfg
}

View File

@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
// Generate max from generator, if no generator was provided the default value returned is 5
max := cfg.MaxFunc(c)
// Don't execute middleware if Next returns true or if the max is 0
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
return c.Next()
}
@ -60,7 +62,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
resetInSec := e.exp - ts
// Set how many hits we have left
remaining := cfg.Max - e.currHits
remaining := max - e.currHits
// Update storage
manager.set(key, e, cfg.Expiration)
@ -68,7 +70,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
// Check if hits exceed the max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitLimit, strconv.Itoa(max))
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

View File

@ -16,7 +16,6 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
@ -28,8 +27,11 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
// Generate max from generator, if no generator was provided the default value returned is 5
max := cfg.MaxFunc(c)
// Don't execute middleware if Next returns true or if the max is 0
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
return c.Next()
}
@ -127,7 +129,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitLimit, strconv.Itoa(max))
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

View File

@ -14,6 +14,132 @@ import (
"github.com/valyala/fasthttp"
)
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int { return 0 },
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
require.NoError(t, err)
require.Equal(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int {
return 0
},
Expiration: 2 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
for i := 0; i <= 4; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello tester!", string(body))
}(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_With_Max_Func -race -v
func Test_Limiter_With_Max_Func(t *testing.T) {
t.Parallel()
app := fiber.New()
max := 10
app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int {
return max
},
Expiration: 2 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
for i := 0; i <= max-1; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello tester!", string(body))
}(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()