📦 Add storage interface to limiter middleware

This commit is contained in:
Tom 2020-10-08 19:54:39 +01:00
parent 8745ad7fc2
commit e8e706c2b8
No known key found for this signature in database
GPG Key ID: D3E7EAA31B39637E
2 changed files with 126 additions and 14 deletions

View File

@ -1,6 +1,8 @@
package limiter
import (
"bytes"
"encoding/gob"
"strconv"
"sync"
"sync/atomic"
@ -39,6 +41,11 @@ type Config struct {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store Storage
}
// ConfigDefault is the default config
@ -52,6 +59,13 @@ var ConfigDefault = Config{
LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
Store: &defaultStore{stmap: map[string][]byte{}},
}
// trackedSession is the type used for session tracking
type trackedSession struct {
Hits int
ResetTime uint64
}
// X-RateLimit-* headers
@ -86,12 +100,15 @@ func New(config ...Config) fiber.Handler {
if cfg.LimitReached == nil {
cfg.LimitReached = ConfigDefault.LimitReached
}
if cfg.Store == nil {
cfg.Store = ConfigDefault.Store
}
}
// Limiter settings
var max = strconv.Itoa(cfg.Max)
var hits = make(map[string]int)
var reset = make(map[string]uint64)
// var hits = make(map[string]int)
// var reset = make(map[string]uint64)
var timestamp = uint64(time.Now().Unix())
var duration = uint64(cfg.Duration.Seconds())
@ -116,33 +133,67 @@ func New(config ...Config) fiber.Handler {
// Get key (default is the remote IP)
key := cfg.Key(c)
// Lock map
// Lock mux (prevents values changing between retrieval and reassignment, which can and does
// break things)
mux.Lock()
// Load data from store
fromStore, err := cfg.Store.Get(key)
if err != nil {
return err
}
// Decode data from store
var session trackedSession
if len(fromStore) == 0 {
// Assume item not found.
session = trackedSession{}
} else {
// Decode bytes using gob
var buf bytes.Buffer
_, _ = buf.Write(fromStore)
dec := gob.NewDecoder(&buf)
err := dec.Decode(&session)
if err != nil {
return err
}
}
// Set unix timestamp if not exist
ts := atomic.LoadUint64(&timestamp)
if reset[key] == 0 {
reset[key] = ts + duration
} else if ts >= reset[key] {
hits[key] = 0
reset[key] = ts + duration
if session.ResetTime == 0 {
session.ResetTime = ts + duration
} else if ts >= session.ResetTime {
session.Hits = 0
session.ResetTime = ts + duration
}
// Increment key hits
hits[key]++
session.Hits++
// Convert session struct into bytes
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err = enc.Encode(session)
if err != nil {
return err
}
// Store those bytes
cfg.Store.Set(key, buf.Bytes(), time.Duration(0))
// Get current hits
hitCount := hits[key]
hitCount := session.Hits
// Calculate when it resets in seconds
resetTime := reset[key] - ts
// Unlock map
mux.Unlock()
resetTime := session.ResetTime - ts
// Set how many hits we have left
remaining := cfg.Max - hitCount
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header

View File

@ -0,0 +1,61 @@
package limiter
import (
"sync"
"time"
)
// Storage interface implemented by providers
type Storage interface {
// Get session value
Get(id string) ([]byte, error)
// Set session value
Set(id string, value []byte, exp time.Duration) error
// Delete session value
Delete(id string) error
// Clear clears the store
Clear() error
}
type defaultStore struct {
stmap map[string][]byte
mutex sync.Mutex
}
func (s *defaultStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return []byte{}, nil
} else {
return val, nil
}
}
func (s *defaultStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s *defaultStore) Clear() error {
s.mutex.Lock()
s.stmap = map[string][]byte{}
s.mutex.Unlock()
return nil
}
func (s *defaultStore) Delete(id string) error {
s.mutex.Lock()
_, ok := s.stmap[id]
if ok {
delete(s.stmap, id)
}
s.mutex.Unlock()
return nil
}