mirror of
https://github.com/gofiber/fiber.git
synced 2025-05-31 11:52:41 +00:00
💼 implement Storage
This commit is contained in:
parent
32fdbf0ddf
commit
ecdda95e15
162
middleware/cache/cache.go
vendored
162
middleware/cache/cache.go
vendored
@ -5,6 +5,7 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@ -26,6 +27,15 @@ type Config struct {
|
|||||||
//
|
//
|
||||||
// Optional. Default: false
|
// Optional. Default: false
|
||||||
CacheControl bool
|
CacheControl bool
|
||||||
|
|
||||||
|
// Store is used to store the state of the middleware
|
||||||
|
//
|
||||||
|
// Default: an in memory store for this process only
|
||||||
|
Store fiber.Storage
|
||||||
|
|
||||||
|
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||||
|
// execution time down.
|
||||||
|
defaultStore bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigDefault is the default config
|
// ConfigDefault is the default config
|
||||||
@ -33,6 +43,7 @@ var ConfigDefault = Config{
|
|||||||
Next: nil,
|
Next: nil,
|
||||||
Expiration: 1 * time.Minute,
|
Expiration: 1 * time.Minute,
|
||||||
CacheControl: false,
|
CacheControl: false,
|
||||||
|
defaultStore: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// cache is the manager to store the cached responses
|
// cache is the manager to store the cached responses
|
||||||
@ -42,14 +53,6 @@ type cache struct {
|
|||||||
expiration int64
|
expiration int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// entry defines the cached response
|
|
||||||
type entry struct {
|
|
||||||
body []byte
|
|
||||||
contentType []byte
|
|
||||||
statusCode int
|
|
||||||
expiration int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new middleware handler
|
// New creates a new middleware handler
|
||||||
func New(config ...Config) fiber.Handler {
|
func New(config ...Config) fiber.Handler {
|
||||||
// Set default config
|
// Set default config
|
||||||
@ -66,8 +69,29 @@ func New(config ...Config) fiber.Handler {
|
|||||||
if int(cfg.Expiration.Seconds()) == 0 {
|
if int(cfg.Expiration.Seconds()) == 0 {
|
||||||
cfg.Expiration = ConfigDefault.Expiration
|
cfg.Expiration = ConfigDefault.Expiration
|
||||||
}
|
}
|
||||||
|
if cfg.Store == nil {
|
||||||
|
cfg.defaultStore = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Cache settings
|
||||||
|
timestamp = uint64(time.Now().Unix())
|
||||||
|
expiration = uint64(cfg.Expiration.Seconds())
|
||||||
|
mux = &sync.RWMutex{}
|
||||||
|
|
||||||
|
// Default store logic (if no Store is provided)
|
||||||
|
entries = make(map[string]entry)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Update timestamp every second
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
atomic.StoreUint64(×tamp, uint64(time.Now().Unix()))
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Nothing to cache
|
// Nothing to cache
|
||||||
if int(cfg.Expiration.Seconds()) < 0 {
|
if int(cfg.Expiration.Seconds()) < 0 {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
@ -75,23 +99,18 @@ func New(config ...Config) fiber.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize db
|
|
||||||
db := &cache{
|
|
||||||
entries: make(map[string]entry),
|
|
||||||
expiration: int64(cfg.Expiration.Seconds()),
|
|
||||||
}
|
|
||||||
// Remove expired entries
|
// Remove expired entries
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
// GC the entries every 10 seconds to avoid
|
// GC the entries every 10 seconds
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
db.Lock()
|
mux.Lock()
|
||||||
for k := range db.entries {
|
for k := range entries {
|
||||||
if time.Now().Unix() >= db.entries[k].expiration {
|
if atomic.LoadUint64(×tamp) >= entries[k].exp {
|
||||||
delete(db.entries, k)
|
delete(entries, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.Unlock()
|
mux.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -110,28 +129,65 @@ func New(config ...Config) fiber.Handler {
|
|||||||
// Get key from request
|
// Get key from request
|
||||||
key := c.Path()
|
key := c.Path()
|
||||||
|
|
||||||
// Find cached entry
|
// Create new entry
|
||||||
db.RLock()
|
entry := entry{}
|
||||||
resp, ok := db.entries[key]
|
|
||||||
db.RUnlock()
|
// Lock entry
|
||||||
if ok {
|
mux.Lock()
|
||||||
// Check if entry is expired
|
defer mux.Unlock()
|
||||||
if time.Now().Unix() >= resp.expiration {
|
|
||||||
db.Lock()
|
// Check if we need to use the default in-memory storage
|
||||||
delete(db.entries, key)
|
if cfg.defaultStore {
|
||||||
db.Unlock()
|
entry = entries[key]
|
||||||
} else {
|
|
||||||
// Set response headers from cache
|
} else {
|
||||||
c.Response().SetBodyRaw(resp.body)
|
// Load data from store
|
||||||
c.Response().SetStatusCode(resp.statusCode)
|
storeEntry, err := cfg.Store.Get(key)
|
||||||
c.Response().Header.SetContentTypeBytes(resp.contentType)
|
if err != nil {
|
||||||
// Set Cache-Control header if enabled
|
return err
|
||||||
if cfg.CacheControl {
|
|
||||||
maxAge := strconv.FormatInt(resp.expiration-time.Now().Unix(), 10)
|
|
||||||
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only decode if we found an entry
|
||||||
|
if len(storeEntry) > 0 {
|
||||||
|
// Decode bytes using msgp
|
||||||
|
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get timestamp
|
||||||
|
ts := atomic.LoadUint64(×tamp)
|
||||||
|
|
||||||
|
// Set expiration if entry does not exist
|
||||||
|
if entry.exp == 0 {
|
||||||
|
entry.exp = ts + expiration
|
||||||
|
|
||||||
|
} else if ts >= entry.exp {
|
||||||
|
// Check if entry is expired
|
||||||
|
// Use default memory storage
|
||||||
|
if cfg.defaultStore {
|
||||||
|
delete(entries, key)
|
||||||
|
} else { // Use custom storage
|
||||||
|
if err := cfg.Store.Delete(key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Set response headers from cache
|
||||||
|
c.Response().SetBodyRaw(entry.body)
|
||||||
|
c.Response().SetStatusCode(entry.status)
|
||||||
|
c.Response().Header.SetContentTypeBytes(entry.cType)
|
||||||
|
|
||||||
|
// Set Cache-Control header if enabled
|
||||||
|
if cfg.CacheControl {
|
||||||
|
maxAge := strconv.FormatUint(entry.exp-ts, 10)
|
||||||
|
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return response
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Continue stack, return err to Fiber if exist
|
// Continue stack, return err to Fiber if exist
|
||||||
@ -140,14 +196,26 @@ func New(config ...Config) fiber.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cache response
|
// Cache response
|
||||||
db.Lock()
|
entry.body = c.Response().Body()
|
||||||
db.entries[key] = entry{
|
entry.status = c.Response().StatusCode()
|
||||||
body: c.Response().Body(),
|
entry.cType = c.Response().Header.ContentType()
|
||||||
statusCode: c.Response().StatusCode(),
|
|
||||||
contentType: c.Response().Header.ContentType(),
|
// Use default memory storage
|
||||||
expiration: time.Now().Unix() + db.expiration,
|
if cfg.defaultStore {
|
||||||
|
entries[key] = entry
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Use custom storage
|
||||||
|
data, err := entry.MarshalMsg(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass bytes to Storage
|
||||||
|
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
db.Unlock()
|
|
||||||
|
|
||||||
// Finish response
|
// Finish response
|
||||||
return nil
|
return nil
|
||||||
|
84
middleware/cache/cache_test.go
vendored
84
middleware/cache/cache_test.go
vendored
@ -6,7 +6,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -91,6 +93,55 @@ func Test_Cache(t *testing.T) {
|
|||||||
utils.AssertEqual(t, cachedBody, body)
|
utils.AssertEqual(t, cachedBody, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// go test -run Test_Cache_Concurrency_Store -race -v
|
||||||
|
func Test_Cache_Concurrency_Store(t *testing.T) {
|
||||||
|
// Test concurrency using a custom store
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
|
return c.SendString("Hello tester!")
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
singleRequest := func(wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i <= 49; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go singleRequest(&wg)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||||
|
cachedResp, err := app.Test(cachedReq)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
utils.AssertEqual(t, cachedBody, body)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_Cache_Invalid_Expiration(t *testing.T) {
|
func Test_Cache_Invalid_Expiration(t *testing.T) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
cache := New(Config{Expiration: 0 * time.Second})
|
cache := New(Config{Expiration: 0 * time.Second})
|
||||||
@ -208,3 +259,36 @@ func Benchmark_Cache(b *testing.B) {
|
|||||||
|
|
||||||
utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
|
utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testStore is used for testing custom stores
|
||||||
|
type testStore struct {
|
||||||
|
stmap map[string][]byte
|
||||||
|
mutex *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s testStore) Get(id string) ([]byte, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
val, ok := s.stmap[id]
|
||||||
|
s.mutex.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return []byte{}, nil
|
||||||
|
} else {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
|
||||||
|
s.mutex.Lock()
|
||||||
|
s.stmap[id] = val
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s testStore) Clear() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s testStore) Delete(id string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
12
middleware/cache/store.go
vendored
Normal file
12
middleware/cache/store.go
vendored
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
// go:generate msgp
|
||||||
|
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||||
|
// don't forget to replace the msgp import path to:
|
||||||
|
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||||
|
type entry struct {
|
||||||
|
body []byte `msg:"body"`
|
||||||
|
cType []byte `msg:"cType"`
|
||||||
|
status int `msg:"status"`
|
||||||
|
exp uint64 `msg:"exp"`
|
||||||
|
}
|
185
middleware/cache/store_msgp.go
vendored
Normal file
185
middleware/cache/store_msgp.go
vendored
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DecodeMsg implements msgp.Decodable
|
||||||
|
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
|
var field []byte
|
||||||
|
_ = field
|
||||||
|
var zb0001 uint32
|
||||||
|
zb0001, err = dc.ReadMapHeader()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for zb0001 > 0 {
|
||||||
|
zb0001--
|
||||||
|
field, err = dc.ReadMapKeyPtr()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgp.UnsafeString(field) {
|
||||||
|
case "body":
|
||||||
|
z.body, err = dc.ReadBytes(z.body)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "cType":
|
||||||
|
z.cType, err = dc.ReadBytes(z.cType)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "cType")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "status":
|
||||||
|
z.status, err = dc.ReadInt()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "exp":
|
||||||
|
z.exp, err = dc.ReadUint64()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "exp")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = dc.Skip()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeMsg implements msgp.Encodable
|
||||||
|
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||||
|
// map header, size 4
|
||||||
|
// write "body"
|
||||||
|
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteBytes(z.body)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// write "cType"
|
||||||
|
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteBytes(z.cType)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "cType")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// write "status"
|
||||||
|
err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteInt(z.status)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// write "exp"
|
||||||
|
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = en.WriteUint64(z.exp)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "exp")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalMsg implements msgp.Marshaler
|
||||||
|
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||||
|
o = msgp.Require(b, z.Msgsize())
|
||||||
|
// map header, size 4
|
||||||
|
// string "body"
|
||||||
|
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||||
|
o = msgp.AppendBytes(o, z.body)
|
||||||
|
// string "cType"
|
||||||
|
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||||
|
o = msgp.AppendBytes(o, z.cType)
|
||||||
|
// string "status"
|
||||||
|
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||||
|
o = msgp.AppendInt(o, z.status)
|
||||||
|
// string "exp"
|
||||||
|
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||||
|
o = msgp.AppendUint64(o, z.exp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalMsg implements msgp.Unmarshaler
|
||||||
|
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
|
var field []byte
|
||||||
|
_ = field
|
||||||
|
var zb0001 uint32
|
||||||
|
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for zb0001 > 0 {
|
||||||
|
zb0001--
|
||||||
|
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgp.UnsafeString(field) {
|
||||||
|
case "body":
|
||||||
|
z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "cType":
|
||||||
|
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "cType")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "status":
|
||||||
|
z.status, bts, err = msgp.ReadIntBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "exp":
|
||||||
|
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "exp")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
bts, err = msgp.Skip(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o = bts
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||||
|
func (z *entry) Msgsize() (s int) {
|
||||||
|
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||||
|
return
|
||||||
|
}
|
@ -119,10 +119,10 @@ func New(config ...Config) fiber.Handler {
|
|||||||
max = strconv.Itoa(cfg.Max)
|
max = strconv.Itoa(cfg.Max)
|
||||||
timestamp = uint64(time.Now().Unix())
|
timestamp = uint64(time.Now().Unix())
|
||||||
expiration = uint64(cfg.Expiration.Seconds())
|
expiration = uint64(cfg.Expiration.Seconds())
|
||||||
|
mux = &sync.RWMutex{}
|
||||||
|
|
||||||
// Default store logic (if no Store is provided)
|
// Default store logic (if no Store is provided)
|
||||||
data = make(map[string]Entry)
|
entries = make(map[string]entry)
|
||||||
mux = &sync.RWMutex{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Update timestamp every second
|
// Update timestamp every second
|
||||||
@ -140,20 +140,20 @@ func New(config ...Config) fiber.Handler {
|
|||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get key (default is the remote IP)
|
// Get key from request
|
||||||
key := cfg.Key(c)
|
key := cfg.Key(c)
|
||||||
|
|
||||||
// Create new entry
|
// Create new entry
|
||||||
entry := Entry{}
|
entry := entry{}
|
||||||
|
|
||||||
// Lock entry
|
// Lock entry
|
||||||
mux.Lock()
|
mux.Lock()
|
||||||
|
defer mux.Unlock()
|
||||||
|
|
||||||
// Check if we need to use the default in-memory storage
|
// Use default memory storage
|
||||||
if cfg.defaultStore {
|
if cfg.defaultStore {
|
||||||
entry = data[key]
|
entry = entries[key]
|
||||||
} else {
|
} else { // Use custom storage
|
||||||
// Load data from store
|
|
||||||
storeEntry, err := cfg.Store.Get(key)
|
storeEntry, err := cfg.Store.Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -167,23 +167,26 @@ func New(config ...Config) fiber.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set unix timestamp if not exist
|
// Get timestamp
|
||||||
ts := atomic.LoadUint64(×tamp)
|
ts := atomic.LoadUint64(×tamp)
|
||||||
if entry.Exp == 0 {
|
|
||||||
entry.Exp = ts + expiration
|
// Set expiration if entry does not exist
|
||||||
} else if ts >= entry.Exp {
|
if entry.exp == 0 {
|
||||||
entry.Hits = 0
|
entry.exp = ts + expiration
|
||||||
entry.Exp = ts + expiration
|
|
||||||
|
} else if ts >= entry.exp {
|
||||||
|
// Check if entry is expired
|
||||||
|
entry.hits = 0
|
||||||
|
entry.exp = ts + expiration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment hits
|
// Increment hits
|
||||||
entry.Hits++
|
entry.hits++
|
||||||
|
|
||||||
// Check if we need to use the default in-memory storage
|
// Use default memory storage
|
||||||
if cfg.defaultStore {
|
if cfg.defaultStore {
|
||||||
data[key] = entry
|
entries[key] = entry
|
||||||
} else {
|
} else { // Use custom storage
|
||||||
// Encode Entry to bytes using msgp
|
|
||||||
data, err := entry.MarshalMsg(nil)
|
data, err := entry.MarshalMsg(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -195,13 +198,11 @@ func New(config ...Config) fiber.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mux.Unlock()
|
|
||||||
|
|
||||||
// Calculate when it resets in seconds
|
// Calculate when it resets in seconds
|
||||||
expire := entry.Exp - ts
|
expire := entry.exp - ts
|
||||||
|
|
||||||
// Set how many hits we have left
|
// Set how many hits we have left
|
||||||
remaining := cfg.Max - entry.Hits
|
remaining := cfg.Max - entry.hits
|
||||||
|
|
||||||
// Check if hits exceed the cfg.Max
|
// Check if hits exceed the cfg.Max
|
||||||
if remaining < 0 {
|
if remaining < 0 {
|
||||||
|
@ -116,7 +116,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Benchmark -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||||
func Benchmark_Limiter(b *testing.B) {
|
func Benchmark_Limiter(b *testing.B) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
package limiter
|
package limiter
|
||||||
|
|
||||||
//go:generate msgp -o=store_msgp.go -tests=false -file=store.go
|
// go:generate msgp
|
||||||
type Entry struct {
|
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||||
Hits int
|
// don't forget to replace the msgp import path to:
|
||||||
Exp uint64
|
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||||
|
type entry struct {
|
||||||
|
hits int
|
||||||
|
exp uint64
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DecodeMsg implements msgp.Decodable
|
// DecodeMsg implements msgp.Decodable
|
||||||
func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
var field []byte
|
var field []byte
|
||||||
_ = field
|
_ = field
|
||||||
var zb0001 uint32
|
var zb0001 uint32
|
||||||
@ -24,16 +24,16 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch msgp.UnsafeString(field) {
|
switch msgp.UnsafeString(field) {
|
||||||
case "Hits":
|
case "hits":
|
||||||
z.Hits, err = dc.ReadInt()
|
z.hits, err = dc.ReadInt()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Hits")
|
err = msgp.WrapError(err, "hits")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "Exp":
|
case "exp":
|
||||||
z.Exp, err = dc.ReadUint64()
|
z.exp, err = dc.ReadUint64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Exp")
|
err = msgp.WrapError(err, "exp")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -48,46 +48,46 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EncodeMsg implements msgp.Encodable
|
// EncodeMsg implements msgp.Encodable
|
||||||
func (z Entry) EncodeMsg(en *msgp.Writer) (err error) {
|
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||||
// map header, size 2
|
// map header, size 2
|
||||||
// write "Hits"
|
// write "hits"
|
||||||
err = en.Append(0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
|
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = en.WriteInt(z.Hits)
|
err = en.WriteInt(z.hits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Hits")
|
err = msgp.WrapError(err, "hits")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// write "Exp"
|
// write "exp"
|
||||||
err = en.Append(0xa3, 0x45, 0x78, 0x70)
|
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = en.WriteUint64(z.Exp)
|
err = en.WriteUint64(z.exp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Exp")
|
err = msgp.WrapError(err, "exp")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalMsg implements msgp.Marshaler
|
// MarshalMsg implements msgp.Marshaler
|
||||||
func (z Entry) MarshalMsg(b []byte) (o []byte, err error) {
|
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||||
o = msgp.Require(b, z.Msgsize())
|
o = msgp.Require(b, z.Msgsize())
|
||||||
// map header, size 2
|
// map header, size 2
|
||||||
// string "Hits"
|
// string "hits"
|
||||||
o = append(o, 0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
|
o = append(o, 0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
|
||||||
o = msgp.AppendInt(o, z.Hits)
|
o = msgp.AppendInt(o, z.hits)
|
||||||
// string "Exp"
|
// string "exp"
|
||||||
o = append(o, 0xa3, 0x45, 0x78, 0x70)
|
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||||
o = msgp.AppendUint64(o, z.Exp)
|
o = msgp.AppendUint64(o, z.exp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalMsg implements msgp.Unmarshaler
|
// UnmarshalMsg implements msgp.Unmarshaler
|
||||||
func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
var field []byte
|
var field []byte
|
||||||
_ = field
|
_ = field
|
||||||
var zb0001 uint32
|
var zb0001 uint32
|
||||||
@ -104,16 +104,16 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch msgp.UnsafeString(field) {
|
switch msgp.UnsafeString(field) {
|
||||||
case "Hits":
|
case "hits":
|
||||||
z.Hits, bts, err = msgp.ReadIntBytes(bts)
|
z.hits, bts, err = msgp.ReadIntBytes(bts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Hits")
|
err = msgp.WrapError(err, "hits")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "Exp":
|
case "exp":
|
||||||
z.Exp, bts, err = msgp.ReadUint64Bytes(bts)
|
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = msgp.WrapError(err, "Exp")
|
err = msgp.WrapError(err, "exp")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -129,7 +129,7 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||||
func (z Entry) Msgsize() (s int) {
|
func (z entry) Msgsize() (s int) {
|
||||||
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
|
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user