mirror of
https://github.com/gofiber/fiber.git
synced 2025-05-31 11:52:41 +00:00
📦 Add storage interface to limiter middleware
This commit is contained in:
parent
8745ad7fc2
commit
e8e706c2b8
@ -1,6 +1,8 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -39,6 +41,11 @@ type Config struct {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store Storage
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
@ -52,6 +59,13 @@ var ConfigDefault = Config{
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
Store: &defaultStore{stmap: map[string][]byte{}},
|
||||
}
|
||||
|
||||
// trackedSession is the type used for session tracking
|
||||
type trackedSession struct {
|
||||
Hits int
|
||||
ResetTime uint64
|
||||
}
|
||||
|
||||
// X-RateLimit-* headers
|
||||
@ -86,12 +100,15 @@ func New(config ...Config) fiber.Handler {
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.Store == nil {
|
||||
cfg.Store = ConfigDefault.Store
|
||||
}
|
||||
}
|
||||
|
||||
// Limiter settings
|
||||
var max = strconv.Itoa(cfg.Max)
|
||||
var hits = make(map[string]int)
|
||||
var reset = make(map[string]uint64)
|
||||
// var hits = make(map[string]int)
|
||||
// var reset = make(map[string]uint64)
|
||||
var timestamp = uint64(time.Now().Unix())
|
||||
var duration = uint64(cfg.Duration.Seconds())
|
||||
|
||||
@ -116,33 +133,67 @@ func New(config ...Config) fiber.Handler {
|
||||
// Get key (default is the remote IP)
|
||||
key := cfg.Key(c)
|
||||
|
||||
// Lock map
|
||||
// Lock mux (prevents values changing between retrieval and reassignment, which can and does
|
||||
// break things)
|
||||
mux.Lock()
|
||||
|
||||
// Load data from store
|
||||
fromStore, err := cfg.Store.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode data from store
|
||||
var session trackedSession
|
||||
|
||||
if len(fromStore) == 0 {
|
||||
// Assume item not found.
|
||||
session = trackedSession{}
|
||||
} else {
|
||||
// Decode bytes using gob
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(fromStore)
|
||||
dec := gob.NewDecoder(&buf)
|
||||
err := dec.Decode(&session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Set unix timestamp if not exist
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
if reset[key] == 0 {
|
||||
reset[key] = ts + duration
|
||||
} else if ts >= reset[key] {
|
||||
hits[key] = 0
|
||||
reset[key] = ts + duration
|
||||
if session.ResetTime == 0 {
|
||||
session.ResetTime = ts + duration
|
||||
} else if ts >= session.ResetTime {
|
||||
session.Hits = 0
|
||||
session.ResetTime = ts + duration
|
||||
}
|
||||
|
||||
// Increment key hits
|
||||
hits[key]++
|
||||
session.Hits++
|
||||
|
||||
// Convert session struct into bytes
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err = enc.Encode(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store those bytes
|
||||
cfg.Store.Set(key, buf.Bytes(), time.Duration(0))
|
||||
|
||||
// Get current hits
|
||||
hitCount := hits[key]
|
||||
hitCount := session.Hits
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetTime := reset[key] - ts
|
||||
|
||||
// Unlock map
|
||||
mux.Unlock()
|
||||
resetTime := session.ResetTime - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - hitCount
|
||||
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
|
61
middleware/limiter/store.go
Normal file
61
middleware/limiter/store.go
Normal file
@ -0,0 +1,61 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Storage interface implemented by providers
|
||||
type Storage interface {
|
||||
// Get session value
|
||||
Get(id string) ([]byte, error)
|
||||
// Set session value
|
||||
Set(id string, value []byte, exp time.Duration) error
|
||||
// Delete session value
|
||||
Delete(id string) error
|
||||
// Clear clears the store
|
||||
Clear() error
|
||||
}
|
||||
|
||||
type defaultStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func (s *defaultStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.Unlock()
|
||||
if !ok {
|
||||
return []byte{}, nil
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *defaultStore) Clear() error {
|
||||
s.mutex.Lock()
|
||||
s.stmap = map[string][]byte{}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *defaultStore) Delete(id string) error {
|
||||
s.mutex.Lock()
|
||||
_, ok := s.stmap[id]
|
||||
if ok {
|
||||
delete(s.stmap, id)
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user