From e8e706c2b80594b2c4fdaad9a443dcb58782e8d0 Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 8 Oct 2020 19:54:39 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=A6=20Add=20storage=20interface=20to?= =?UTF-8?q?=20limiter=20middleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/limiter/limiter.go | 79 ++++++++++++++++++++++++++++------- middleware/limiter/store.go | 61 +++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 14 deletions(-) create mode 100644 middleware/limiter/store.go diff --git a/middleware/limiter/limiter.go b/middleware/limiter/limiter.go index c3eb3fe0..f01249e8 100644 --- a/middleware/limiter/limiter.go +++ b/middleware/limiter/limiter.go @@ -1,6 +1,8 @@ package limiter import ( + "bytes" + "encoding/gob" "strconv" "sync" "sync/atomic" @@ -39,6 +41,11 @@ type Config struct { // return c.SendStatus(fiber.StatusTooManyRequests) // } LimitReached fiber.Handler + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Store Storage } // ConfigDefault is the default config @@ -52,6 +59,13 @@ var ConfigDefault = Config{ LimitReached: func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, + Store: &defaultStore{stmap: map[string][]byte{}}, +} + +// trackedSession is the type used for session tracking +type trackedSession struct { + Hits int + ResetTime uint64 } // X-RateLimit-* headers @@ -86,12 +100,15 @@ func New(config ...Config) fiber.Handler { if cfg.LimitReached == nil { cfg.LimitReached = ConfigDefault.LimitReached } + if cfg.Store == nil { + cfg.Store = ConfigDefault.Store + } } // Limiter settings var max = strconv.Itoa(cfg.Max) - var hits = make(map[string]int) - var reset = make(map[string]uint64) + // var hits = make(map[string]int) + // var reset = make(map[string]uint64) var timestamp = uint64(time.Now().Unix()) var duration = uint64(cfg.Duration.Seconds()) @@ -116,33 +133,67 @@ func New(config ...Config) fiber.Handler { // Get key (default is the remote IP) key := cfg.Key(c) - // Lock map + // Lock mux (prevents values changing between retrieval and reassignment, which can and does + // break things) mux.Lock() + // Load data from store + fromStore, err := cfg.Store.Get(key) + if err != nil { + return err + } + + // Decode data from store + var session trackedSession + + if len(fromStore) == 0 { + // Assume item not found. + session = trackedSession{} + } else { + // Decode bytes using gob + var buf bytes.Buffer + _, _ = buf.Write(fromStore) + dec := gob.NewDecoder(&buf) + err := dec.Decode(&session) + if err != nil { + return err + } + } + // Set unix timestamp if not exist ts := atomic.LoadUint64(×tamp) - if reset[key] == 0 { - reset[key] = ts + duration - } else if ts >= reset[key] { - hits[key] = 0 - reset[key] = ts + duration + if session.ResetTime == 0 { + session.ResetTime = ts + duration + } else if ts >= session.ResetTime { + session.Hits = 0 + session.ResetTime = ts + duration } // Increment key hits - hits[key]++ + session.Hits++ + + // Convert session struct into bytes + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err = enc.Encode(session) + if err != nil { + return err + } + + // Store those bytes + cfg.Store.Set(key, buf.Bytes(), time.Duration(0)) // Get current hits - hitCount := hits[key] + hitCount := session.Hits // Calculate when it resets in seconds - resetTime := reset[key] - ts - - // Unlock map - mux.Unlock() + resetTime := session.ResetTime - ts // Set how many hits we have left remaining := cfg.Max - hitCount + mux.Unlock() + // Check if hits exceed the cfg.Max if remaining < 0 { // Return response with Retry-After header diff --git a/middleware/limiter/store.go b/middleware/limiter/store.go new file mode 100644 index 00000000..3e070741 --- /dev/null +++ b/middleware/limiter/store.go @@ -0,0 +1,61 @@ +package limiter + +import ( + "sync" + "time" +) + +// Storage interface implemented by providers +type Storage interface { + // Get session value + Get(id string) ([]byte, error) + // Set session value + Set(id string, value []byte, exp time.Duration) error + // Delete session value + Delete(id string) error + // Clear clears the store + Clear() error +} + +type defaultStore struct { + stmap map[string][]byte + mutex sync.Mutex +} + +func (s *defaultStore) Get(id string) ([]byte, error) { + s.mutex.Lock() + val, ok := s.stmap[id] + s.mutex.Unlock() + if !ok { + return []byte{}, nil + } else { + return val, nil + } +} + +func (s *defaultStore) Set(id string, val []byte, _ time.Duration) error { + s.mutex.Lock() + s.stmap[id] = val + s.mutex.Unlock() + + return nil +} + +func (s *defaultStore) Clear() error { + s.mutex.Lock() + s.stmap = map[string][]byte{} + s.mutex.Unlock() + + return nil +} + +func (s *defaultStore) Delete(id string) error { + s.mutex.Lock() + _, ok := s.stmap[id] + if ok { + delete(s.stmap, id) + } + s.mutex.Unlock() + + return nil +}