mirror of https://github.com/gofiber/fiber.git
✨ feat: Add Max Func to Limiter Middleware (#3070)
* feat: add max calculator to limiter middleware * docs: update docs including the new parameter * refactor: add new line before go code in docs * fix: use crypto/rand instead of math/rand on tests * test: add new test with zero set as limit * fix: repeated tests failing when generating random limits * fix: wrong type of MaxCalculator in docs * feat: include max calculator in limiter_sliding * refactor: rename MaxCalculator to MaxFunc * docs: update docs with MaxFunc parameter * tests: rename tests and add test for limiter slidingpull/3087/head
parent
486304d050
commit
011c8f8007
|
@ -43,6 +43,9 @@ app.Use(limiter.New(limiter.Config{
|
||||||
return c.IP() == "127.0.0.1"
|
return c.IP() == "127.0.0.1"
|
||||||
},
|
},
|
||||||
Max: 20,
|
Max: 20,
|
||||||
|
MaxFunc: func(c fiber.Ctx) int {
|
||||||
|
return 20
|
||||||
|
},
|
||||||
Expiration: 30 * time.Second,
|
Expiration: 30 * time.Second,
|
||||||
KeyGenerator: func(c fiber.Ctx) string {
|
KeyGenerator: func(c fiber.Ctx) string {
|
||||||
return c.Get("x-forwarded-for")
|
return c.Get("x-forwarded-for")
|
||||||
|
@ -75,12 +78,28 @@ weightOfPreviousWindow = previous window's amount request * (whenNewWindow / Exp
|
||||||
rate = weightOfPreviousWindow + current window's amount request.
|
rate = weightOfPreviousWindow + current window's amount request.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Dynamic limit
|
||||||
|
|
||||||
|
You can also calculate the limit dynamically using the MaxFunc parameter. It's a function that receives the request's context as a parameter and allow you to calculate a different limit for each request separately.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
app.Use(limiter.New(limiter.Config{
|
||||||
|
MaxFunc: func(c fiber.Ctx) int {
|
||||||
|
return getUserLimit(ctx.Param("id"))
|
||||||
|
},
|
||||||
|
Expiration: 30 * time.Second,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
## Config
|
## Config
|
||||||
|
|
||||||
| Property | Type | Description | Default |
|
| Property | Type | Description | Default |
|
||||||
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
|
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
|
||||||
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
|
||||||
| Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 |
|
| Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 |
|
||||||
|
| MaxFunc | `func(fiber.Ctx) int` | A function to calculate the max number of recent connections during `Expiration` seconds before sending a 429 response. | A function which returns the cfg.Max |
|
||||||
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default |
|
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default |
|
||||||
| Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute |
|
| Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute |
|
||||||
| LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response |
|
| LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response |
|
||||||
|
@ -101,6 +120,9 @@ A custom store can be used if it implements the `Storage` interface - more detai
|
||||||
```go
|
```go
|
||||||
var ConfigDefault = Config{
|
var ConfigDefault = Config{
|
||||||
Max: 5,
|
Max: 5,
|
||||||
|
MaxFunc: func(c fiber.Ctx) int {
|
||||||
|
return 5
|
||||||
|
},
|
||||||
Expiration: 1 * time.Minute,
|
Expiration: 1 * time.Minute,
|
||||||
KeyGenerator: func(c fiber.Ctx) string {
|
KeyGenerator: func(c fiber.Ctx) string {
|
||||||
return c.IP()
|
return c.IP()
|
||||||
|
|
|
@ -22,6 +22,13 @@ type Config struct {
|
||||||
// Optional. Default: nil
|
// Optional. Default: nil
|
||||||
Next func(c fiber.Ctx) bool
|
Next func(c fiber.Ctx) bool
|
||||||
|
|
||||||
|
// A function to dynamically calculate the max requests supported by the rate limiter middleware
|
||||||
|
//
|
||||||
|
// Default: func(c fiber.Ctx) int {
|
||||||
|
// return c.Max
|
||||||
|
// }
|
||||||
|
MaxFunc func(c fiber.Ctx) int
|
||||||
|
|
||||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||||
//
|
//
|
||||||
// Default: func(c fiber.Ctx) string {
|
// Default: func(c fiber.Ctx) string {
|
||||||
|
@ -101,5 +108,10 @@ func configDefault(config ...Config) Config {
|
||||||
if cfg.LimiterMiddleware == nil {
|
if cfg.LimiterMiddleware == nil {
|
||||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||||
}
|
}
|
||||||
|
if cfg.MaxFunc == nil {
|
||||||
|
cfg.MaxFunc = func(_ fiber.Ctx) int {
|
||||||
|
return cfg.Max
|
||||||
|
}
|
||||||
|
}
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||||
var (
|
var (
|
||||||
// Limiter variables
|
// Limiter variables
|
||||||
mux = &sync.RWMutex{}
|
mux = &sync.RWMutex{}
|
||||||
max = strconv.Itoa(cfg.Max)
|
|
||||||
expiration = uint64(cfg.Expiration.Seconds())
|
expiration = uint64(cfg.Expiration.Seconds())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c fiber.Ctx) error {
|
return func(c fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Generate max from generator, if no generator was provided the default value returned is 5
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
max := cfg.MaxFunc(c)
|
||||||
|
|
||||||
|
// Don't execute middleware if Next returns true or if the max is 0
|
||||||
|
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,7 +62,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||||
resetInSec := e.exp - ts
|
resetInSec := e.exp - ts
|
||||||
|
|
||||||
// Set how many hits we have left
|
// Set how many hits we have left
|
||||||
remaining := cfg.Max - e.currHits
|
remaining := max - e.currHits
|
||||||
|
|
||||||
// Update storage
|
// Update storage
|
||||||
manager.set(key, e, cfg.Expiration)
|
manager.set(key, e, cfg.Expiration)
|
||||||
|
@ -68,7 +70,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||||
// Unlock entry
|
// Unlock entry
|
||||||
mux.Unlock()
|
mux.Unlock()
|
||||||
|
|
||||||
// Check if hits exceed the cfg.Max
|
// Check if hits exceed the max
|
||||||
if remaining < 0 {
|
if remaining < 0 {
|
||||||
// Return response with Retry-After header
|
// Return response with Retry-After header
|
||||||
// https://tools.ietf.org/html/rfc6584
|
// https://tools.ietf.org/html/rfc6584
|
||||||
|
@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can continue, update RateLimit headers
|
// We can continue, update RateLimit headers
|
||||||
c.Set(xRateLimitLimit, max)
|
c.Set(xRateLimitLimit, strconv.Itoa(max))
|
||||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||||
var (
|
var (
|
||||||
// Limiter variables
|
// Limiter variables
|
||||||
mux = &sync.RWMutex{}
|
mux = &sync.RWMutex{}
|
||||||
max = strconv.Itoa(cfg.Max)
|
|
||||||
expiration = uint64(cfg.Expiration.Seconds())
|
expiration = uint64(cfg.Expiration.Seconds())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,8 +27,11 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||||
|
|
||||||
// Return new handler
|
// Return new handler
|
||||||
return func(c fiber.Ctx) error {
|
return func(c fiber.Ctx) error {
|
||||||
// Don't execute middleware if Next returns true
|
// Generate max from generator, if no generator was provided the default value returned is 5
|
||||||
if cfg.Next != nil && cfg.Next(c) {
|
max := cfg.MaxFunc(c)
|
||||||
|
|
||||||
|
// Don't execute middleware if Next returns true or if the max is 0
|
||||||
|
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,7 +129,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can continue, update RateLimit headers
|
// We can continue, update RateLimit headers
|
||||||
c.Set(xRateLimitLimit, max)
|
c.Set(xRateLimitLimit, strconv.Itoa(max))
|
||||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,132 @@ import (
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
|
||||||
|
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
MaxFunc: func(_ fiber.Ctx) int { return 0 },
|
||||||
|
Expiration: 2 * time.Second,
|
||||||
|
SkipFailedRequests: false,
|
||||||
|
SkipSuccessfulRequests: false,
|
||||||
|
LimiterMiddleware: SlidingWindow{},
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/:status", func(c fiber.Ctx) error {
|
||||||
|
if c.Params("status") == "fail" {
|
||||||
|
return c.SendStatus(400)
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
|
||||||
|
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
MaxFunc: func(_ fiber.Ctx) int {
|
||||||
|
return 0
|
||||||
|
},
|
||||||
|
Expiration: 2 * time.Second,
|
||||||
|
Storage: memory.New(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/", func(c fiber.Ctx) error {
|
||||||
|
return c.SendString("Hello tester!")
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i <= 4; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "Hello tester!", string(body))
|
||||||
|
}(&wg)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// go test -run Test_Limiter_With_Max_Func -race -v
|
||||||
|
func Test_Limiter_With_Max_Func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
app := fiber.New()
|
||||||
|
max := 10
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
MaxFunc: func(_ fiber.Ctx) int {
|
||||||
|
return max
|
||||||
|
},
|
||||||
|
Expiration: 2 * time.Second,
|
||||||
|
Storage: memory.New(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/", func(c fiber.Ctx) error {
|
||||||
|
return c.SendString("Hello tester!")
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i <= max-1; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "Hello tester!", string(body))
|
||||||
|
}(&wg)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// go test -run Test_Limiter_Concurrency_Store -race -v
|
// go test -run Test_Limiter_Concurrency_Store -race -v
|
||||||
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
Loading…
Reference in New Issue