mirror of https://github.com/gofiber/fiber.git
✨ feat: Add Max Func to Limiter Middleware (#3070)
* feat: add max calculator to limiter middleware * docs: update docs including the new parameter * refactor: add new line before go code in docs * fix: use crypto/rand instead of math/rand on tests * test: add new test with zero set as limit * fix: repeated tests failing when generating random limits * fix: wrong type of MaxCalculator in docs * feat: include max calculator in limiter_sliding * refactor: rename MaxCalculator to MaxFunc * docs: update docs with MaxFunc parameter * tests: rename tests and add test for limiter slidingpull/3076/head
parent
d7d91598c5
commit
c90219026f
|
@ -43,6 +43,9 @@ app.Use(limiter.New(limiter.Config{
|
|||
return c.IP() == "127.0.0.1"
|
||||
},
|
||||
Max: 20,
|
||||
MaxFunc: func(c fiber.Ctx) int {
|
||||
return 20
|
||||
},
|
||||
Expiration: 30 * time.Second,
|
||||
KeyGenerator: func(c fiber.Ctx) string {
|
||||
return c.Get("x-forwarded-for")
|
||||
|
@ -75,12 +78,28 @@ weightOfPreviousWindow = previous window's amount request * (whenNewWindow / Exp
|
|||
rate = weightOfPreviousWindow + current window's amount request.
|
||||
```
|
||||
|
||||
## Dynamic limit
|
||||
|
||||
You can also calculate the limit dynamically using the MaxFunc parameter. It's a function that receives the request's context as a parameter and allow you to calculate a different limit for each request separately.
|
||||
|
||||
Example:
|
||||
|
||||
```go
|
||||
app.Use(limiter.New(limiter.Config{
|
||||
MaxFunc: func(c fiber.Ctx) int {
|
||||
return getUserLimit(ctx.Param("id"))
|
||||
},
|
||||
Expiration: 30 * time.Second,
|
||||
}))
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
| Property | Type | Description | Default |
|
||||
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
|
||||
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
|
||||
| Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 |
|
||||
| MaxFunc | `func(fiber.Ctx) int` | A function to calculate the max number of recent connections during `Expiration` seconds before sending a 429 response. | A function which returns the cfg.Max |
|
||||
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default |
|
||||
| Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute |
|
||||
| LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response |
|
||||
|
@ -101,6 +120,9 @@ A custom store can be used if it implements the `Storage` interface - more detai
|
|||
```go
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
MaxFunc: func(c fiber.Ctx) int {
|
||||
return 5
|
||||
},
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c fiber.Ctx) string {
|
||||
return c.IP()
|
||||
|
|
|
@ -22,6 +22,13 @@ type Config struct {
|
|||
// Optional. Default: nil
|
||||
Next func(c fiber.Ctx) bool
|
||||
|
||||
// A function to dynamically calculate the max requests supported by the rate limiter middleware
|
||||
//
|
||||
// Default: func(c fiber.Ctx) int {
|
||||
// return c.Max
|
||||
// }
|
||||
MaxFunc func(c fiber.Ctx) int
|
||||
|
||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c fiber.Ctx) string {
|
||||
|
@ -101,5 +108,10 @@ func configDefault(config ...Config) Config {
|
|||
if cfg.LimiterMiddleware == nil {
|
||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||
}
|
||||
if cfg.MaxFunc == nil {
|
||||
cfg.MaxFunc = func(_ fiber.Ctx) int {
|
||||
return cfg.Max
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
|||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
|
@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
|||
|
||||
// Return new handler
|
||||
return func(c fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
// Generate max from generator, if no generator was provided the default value returned is 5
|
||||
max := cfg.MaxFunc(c)
|
||||
|
||||
// Don't execute middleware if Next returns true or if the max is 0
|
||||
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
@ -60,7 +62,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
|||
resetInSec := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - e.currHits
|
||||
remaining := max - e.currHits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
@ -68,7 +70,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
|||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
// Check if hits exceed the max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
|
@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitLimit, strconv.Itoa(max))
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
|||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
|
@ -28,8 +27,11 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
|||
|
||||
// Return new handler
|
||||
return func(c fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
// Generate max from generator, if no generator was provided the default value returned is 5
|
||||
max := cfg.MaxFunc(c)
|
||||
|
||||
// Don't execute middleware if Next returns true or if the max is 0
|
||||
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
@ -127,7 +129,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitLimit, strconv.Itoa(max))
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
|
|
|
@ -14,6 +14,132 @@ import (
|
|||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
|
||||
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
MaxFunc: func(_ fiber.Ctx) int { return 0 },
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
|
||||
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
MaxFunc: func(_ fiber.Ctx) int {
|
||||
return 0
|
||||
},
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i <= 4; i++ {
|
||||
wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello tester!", string(body))
|
||||
}(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_With_Max_Func -race -v
|
||||
func Test_Limiter_With_Max_Func(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
max := 10
|
||||
|
||||
app.Use(New(Config{
|
||||
MaxFunc: func(_ fiber.Ctx) int {
|
||||
return max
|
||||
},
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i <= max-1; i++ {
|
||||
wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello tester!", string(body))
|
||||
}(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Concurrency_Store -race -v
|
||||
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
Loading…
Reference in New Issue