mirror of https://github.com/gofiber/fiber.git
139 lines
4.0 KiB
Go
139 lines
4.0 KiB
Go
package limiter
|
|
|
|
import (
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/utils/v2"
|
|
)
|
|
|
|
// SlidingWindow is a field-less type whose New method builds the
// sliding window rate limiting handler.
type SlidingWindow struct{}
|
|
|
|
// New creates a new sliding window middleware handler
|
|
func (SlidingWindow) New(cfg Config) fiber.Handler {
|
|
var (
|
|
// Limiter variables
|
|
mux = &sync.RWMutex{}
|
|
expiration = uint64(cfg.Expiration.Seconds())
|
|
)
|
|
|
|
// Create manager to simplify storage operations ( see manager.go )
|
|
manager := newManager(cfg.Storage)
|
|
|
|
// Update timestamp every second
|
|
utils.StartTimeStampUpdater()
|
|
|
|
// Return new handler
|
|
return func(c fiber.Ctx) error {
|
|
// Generate maxRequests from generator, if no generator was provided the default value returned is 5
|
|
maxRequests := cfg.MaxFunc(c)
|
|
|
|
// Don't execute middleware if Next returns true or if the max is 0
|
|
if (cfg.Next != nil && cfg.Next(c)) || maxRequests == 0 {
|
|
return c.Next()
|
|
}
|
|
|
|
// Get key from request
|
|
key := cfg.KeyGenerator(c)
|
|
|
|
// Lock entry
|
|
mux.Lock()
|
|
|
|
// Get entry from pool and release when finished
|
|
e := manager.get(key)
|
|
|
|
// Get timestamp
|
|
ts := uint64(utils.Timestamp())
|
|
|
|
// Set expiration if entry does not exist
|
|
if e.exp == 0 {
|
|
e.exp = ts + expiration
|
|
} else if ts >= e.exp {
|
|
// The entry has expired, handle the expiration.
|
|
// Set the prevHits to the current hits and reset the hits to 0.
|
|
e.prevHits = e.currHits
|
|
|
|
// Reset the current hits to 0.
|
|
e.currHits = 0
|
|
|
|
// Check how much into the current window it currently is and sets the
|
|
// expiry based on that, otherwise this would only reset on
|
|
// the next request and not show the correct expiry.
|
|
elapsed := ts - e.exp
|
|
if elapsed >= expiration {
|
|
e.exp = ts + expiration
|
|
} else {
|
|
e.exp = ts + expiration - elapsed
|
|
}
|
|
}
|
|
|
|
// Increment hits
|
|
e.currHits++
|
|
|
|
// Calculate when it resets in seconds
|
|
resetInSec := e.exp - ts
|
|
|
|
// weight = time until current window reset / total window length
|
|
weight := float64(resetInSec) / float64(expiration)
|
|
|
|
// rate = request count in previous window - weight + request count in current window
|
|
rate := int(float64(e.prevHits)*weight) + e.currHits
|
|
|
|
// Calculate how many hits can be made based on the current rate
|
|
remaining := cfg.Max - rate
|
|
|
|
// Update storage. Garbage collect when the next window ends.
|
|
// |--------------------------|--------------------------|
|
|
// ^ ^ ^ ^
|
|
// ts e.exp End sample window End next window
|
|
// <------------>
|
|
// Reset In Sec
|
|
// resetInSec = e.exp - ts - time until end of current window.
|
|
// duration + expiration = end of next window.
|
|
// Because we don't want to garbage collect in the middle of a window
|
|
// we add the expiration to the duration.
|
|
// Otherwise after the end of "sample window", attackers could launch
|
|
// a new request with the full window length.
|
|
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second) //nolint:gosec // Not a concern
|
|
|
|
// Unlock entry
|
|
mux.Unlock()
|
|
|
|
// Check if hits exceed the cfg.Max
|
|
if remaining < 0 {
|
|
// Return response with Retry-After header
|
|
// https://tools.ietf.org/html/rfc6584
|
|
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
|
|
|
// Call LimitReached handler
|
|
return cfg.LimitReached(c)
|
|
}
|
|
|
|
// Continue stack for reaching c.Response().StatusCode()
|
|
// Store err for returning
|
|
err := c.Next()
|
|
|
|
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
|
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
|
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
|
// Lock entry
|
|
mux.Lock()
|
|
e = manager.get(key)
|
|
e.currHits--
|
|
remaining++
|
|
manager.set(key, e, cfg.Expiration)
|
|
// Unlock entry
|
|
mux.Unlock()
|
|
}
|
|
|
|
// We can continue, update RateLimit headers
|
|
c.Set(xRateLimitLimit, strconv.Itoa(maxRequests))
|
|
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
|
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
|
|
|
return err
|
|
}
|
|
}
|