diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 209cd8d9..268d0262 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -5,6 +5,7 @@ package cache import ( "strconv" "sync" + "sync/atomic" "time" "github.com/gofiber/fiber/v2" @@ -26,6 +27,15 @@ type Config struct { // // Optional. Default: false CacheControl bool + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Store fiber.Storage + + // Internally used - if true, the simpler method of two maps is used in order to keep + // execution time down. + defaultStore bool } // ConfigDefault is the default config @@ -33,6 +43,7 @@ var ConfigDefault = Config{ Next: nil, Expiration: 1 * time.Minute, CacheControl: false, + defaultStore: true, } // cache is the manager to store the cached responses @@ -42,14 +53,6 @@ type cache struct { expiration int64 } -// entry defines the cached response -type entry struct { - body []byte - contentType []byte - statusCode int - expiration int64 -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config @@ -66,8 +69,29 @@ func New(config ...Config) fiber.Handler { if int(cfg.Expiration.Seconds()) == 0 { cfg.Expiration = ConfigDefault.Expiration } + if cfg.Store == nil { + cfg.defaultStore = true + } } + var ( + // Cache settings + timestamp = uint64(time.Now().Unix()) + expiration = uint64(cfg.Expiration.Seconds()) + mux = &sync.RWMutex{} + + // Default store logic (if no Store is provided) + entries = make(map[string]entry) + ) + + // Update timestamp every second + go func() { + for { + atomic.StoreUint64(×tamp, uint64(time.Now().Unix())) + time.Sleep(1 * time.Second) + } + }() + // Nothing to cache if int(cfg.Expiration.Seconds()) < 0 { return func(c *fiber.Ctx) error { @@ -75,23 +99,18 @@ func New(config ...Config) fiber.Handler { } } - // Initialize db - db := &cache{ - entries: make(map[string]entry), - expiration: int64(cfg.Expiration.Seconds()), - } // Remove expired entries go func() { for { - // GC the entries every 10 seconds to avoid + // GC the entries every 10 seconds time.Sleep(10 * time.Second) - db.Lock() - for k := range db.entries { - if time.Now().Unix() >= db.entries[k].expiration { - delete(db.entries, k) + mux.Lock() + for k := range entries { + if atomic.LoadUint64(×tamp) >= entries[k].exp { + delete(entries, k) } } - db.Unlock() + mux.Unlock() } }() @@ -110,28 +129,65 @@ func New(config ...Config) fiber.Handler { // Get key from request key := c.Path() - // Find cached entry - db.RLock() - resp, ok := db.entries[key] - db.RUnlock() - if ok { - // Check if entry is expired - if time.Now().Unix() >= resp.expiration { - db.Lock() - delete(db.entries, key) - db.Unlock() - } else { - // Set response headers from cache - c.Response().SetBodyRaw(resp.body) - c.Response().SetStatusCode(resp.statusCode) - c.Response().Header.SetContentTypeBytes(resp.contentType) - // Set Cache-Control header if enabled - if cfg.CacheControl { - maxAge := strconv.FormatInt(resp.expiration-time.Now().Unix(), 10) - c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) - } - return nil + // Create new entry + entry := entry{} + + // Lock entry + mux.Lock() + defer mux.Unlock() + + // Check if we need to use the default in-memory storage + if cfg.defaultStore { + entry = entries[key] + + } else { + // Load data from store + storeEntry, err := cfg.Store.Get(key) + if err != nil { + return err } + + // Only decode if we found an entry + if len(storeEntry) > 0 { + // Decode bytes using msgp + if _, err := entry.UnmarshalMsg(storeEntry); err != nil { + return err + } + } + } + + // Get timestamp + ts := atomic.LoadUint64(×tamp) + + // Set expiration if entry does not exist + if entry.exp == 0 { + entry.exp = ts + expiration + + } else if ts >= entry.exp { + // Check if entry is expired + // Use default memory storage + if cfg.defaultStore { + delete(entries, key) + } else { // Use custom storage + if err := cfg.Store.Delete(key); err != nil { + return err + } + } + + } else { + // Set response headers from cache + c.Response().SetBodyRaw(entry.body) + c.Response().SetStatusCode(entry.status) + c.Response().Header.SetContentTypeBytes(entry.cType) + + // Set Cache-Control header if enabled + if cfg.CacheControl { + maxAge := strconv.FormatUint(entry.exp-ts, 10) + c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) + } + + // Return response + return nil } // Continue stack, return err to Fiber if exist @@ -140,14 +196,26 @@ func New(config ...Config) fiber.Handler { } // Cache response - db.Lock() - db.entries[key] = entry{ - body: c.Response().Body(), - statusCode: c.Response().StatusCode(), - contentType: c.Response().Header.ContentType(), - expiration: time.Now().Unix() + db.expiration, + entry.body = c.Response().Body() + entry.status = c.Response().StatusCode() + entry.cType = c.Response().Header.ContentType() + + // Use default memory storage + if cfg.defaultStore { + entries[key] = entry + + } else { + // Use custom storage + data, err := entry.MarshalMsg(nil) + if err != nil { + return err + } + + // Pass bytes to Storage + if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil { + return err + } } - db.Unlock() // Finish response return nil diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index da102759..d722de58 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -6,7 +6,9 @@ import ( "bytes" "fmt" "io/ioutil" + "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -91,6 +93,55 @@ func Test_Cache(t *testing.T) { utils.AssertEqual(t, cachedBody, body) } +// go test -run Test_Cache_Concurrency_Store -race -v +func Test_Cache_Concurrency_Store(t *testing.T) { + // Test concurrency using a custom store + + app := fiber.New() + + app.Use(New(Config{ + Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)}, + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello tester!") + }) + + var wg sync.WaitGroup + singleRequest := func(wg *sync.WaitGroup) { + defer wg.Done() + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "Hello tester!", string(body)) + } + + for i := 0; i <= 49; i++ { + wg.Add(1) + go singleRequest(&wg) + } + + wg.Wait() + + req := httptest.NewRequest("GET", "/", nil) + resp, err := app.Test(req) + utils.AssertEqual(t, nil, err) + + cachedReq := httptest.NewRequest("GET", "/", nil) + cachedResp, err := app.Test(cachedReq) + utils.AssertEqual(t, nil, err) + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + cachedBody, err := ioutil.ReadAll(cachedResp.Body) + utils.AssertEqual(t, nil, err) + + utils.AssertEqual(t, cachedBody, body) +} + func Test_Cache_Invalid_Expiration(t *testing.T) { app := fiber.New() cache := New(Config{Expiration: 0 * time.Second}) @@ -208,3 +259,36 @@ func Benchmark_Cache(b *testing.B) { utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode()) } + +// testStore is used for testing custom stores +type testStore struct { + stmap map[string][]byte + mutex *sync.Mutex +} + +func (s testStore) Get(id string) ([]byte, error) { + s.mutex.Lock() + val, ok := s.stmap[id] + s.mutex.Unlock() + if !ok { + return []byte{}, nil + } else { + return val, nil + } +} + +func (s testStore) Set(id string, val []byte, _ time.Duration) error { + s.mutex.Lock() + s.stmap[id] = val + s.mutex.Unlock() + + return nil +} + +func (s testStore) Clear() error { + return nil +} + +func (s testStore) Delete(id string) error { + return nil +} diff --git a/middleware/cache/store.go b/middleware/cache/store.go new file mode 100644 index 00000000..92bf0935 --- /dev/null +++ b/middleware/cache/store.go @@ -0,0 +1,12 @@ +package cache + +// go:generate msgp +// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported +// don't forget to replace the msgp import path to: +// "github.com/gofiber/fiber/v2/internal/msgp" +type entry struct { + body []byte `msg:"body"` + cType []byte `msg:"cType"` + status int `msg:"status"` + exp uint64 `msg:"exp"` +} diff --git a/middleware/cache/store_msgp.go b/middleware/cache/store_msgp.go new file mode 100644 index 00000000..a5b7a984 --- /dev/null +++ b/middleware/cache/store_msgp.go @@ -0,0 +1,185 @@ +package cache + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/gofiber/fiber/v2/internal/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "body": + z.body, err = dc.ReadBytes(z.body) + if err != nil { + err = msgp.WrapError(err, "body") + return + } + case "cType": + z.cType, err = dc.ReadBytes(z.cType) + if err != nil { + err = msgp.WrapError(err, "cType") + return + } + case "status": + z.status, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "status") + return + } + case "exp": + z.exp, err = dc.ReadUint64() + if err != nil { + err = msgp.WrapError(err, "exp") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *entry) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 4 + // write "body" + err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79) + if err != nil { + return + } + err = en.WriteBytes(z.body) + if err != nil { + err = msgp.WrapError(err, "body") + return + } + // write "cType" + err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65) + if err != nil { + return + } + err = en.WriteBytes(z.cType) + if err != nil { + err = msgp.WrapError(err, "cType") + return + } + // write "status" + err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73) + if err != nil { + return + } + err = en.WriteInt(z.status) + if err != nil { + err = msgp.WrapError(err, "status") + return + } + // write "exp" + err = en.Append(0xa3, 0x65, 0x78, 0x70) + if err != nil { + return + } + err = en.WriteUint64(z.exp) + if err != nil { + err = msgp.WrapError(err, "exp") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *entry) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 4 + // string "body" + o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79) + o = msgp.AppendBytes(o, z.body) + // string "cType" + o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65) + o = msgp.AppendBytes(o, z.cType) + // string "status" + o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73) + o = msgp.AppendInt(o, z.status) + // string "exp" + o = append(o, 0xa3, 0x65, 0x78, 0x70) + o = msgp.AppendUint64(o, z.exp) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "body": + z.body, bts, err = msgp.ReadBytesBytes(bts, z.body) + if err != nil { + err = msgp.WrapError(err, "body") + return + } + case "cType": + z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType) + if err != nil { + err = msgp.WrapError(err, "cType") + return + } + case "status": + z.status, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "status") + return + } + case "exp": + z.exp, bts, err = msgp.ReadUint64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "exp") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *entry) Msgsize() (s int) { + s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + return +} diff --git a/middleware/limiter/limiter.go b/middleware/limiter/limiter.go index 1ed67b20..6044d87c 100644 --- a/middleware/limiter/limiter.go +++ b/middleware/limiter/limiter.go @@ -119,10 +119,10 @@ func New(config ...Config) fiber.Handler { max = strconv.Itoa(cfg.Max) timestamp = uint64(time.Now().Unix()) expiration = uint64(cfg.Expiration.Seconds()) + mux = &sync.RWMutex{} // Default store logic (if no Store is provided) - data = make(map[string]Entry) - mux = &sync.RWMutex{} + entries = make(map[string]entry) ) // Update timestamp every second @@ -140,20 +140,20 @@ func New(config ...Config) fiber.Handler { return c.Next() } - // Get key (default is the remote IP) + // Get key from request key := cfg.Key(c) // Create new entry - entry := Entry{} + entry := entry{} // Lock entry mux.Lock() + defer mux.Unlock() - // Check if we need to use the default in-memory storage + // Use default memory storage if cfg.defaultStore { - entry = data[key] - } else { - // Load data from store + entry = entries[key] + } else { // Use custom storage storeEntry, err := cfg.Store.Get(key) if err != nil { return err @@ -167,23 +167,26 @@ func New(config ...Config) fiber.Handler { } } - // Set unix timestamp if not exist + // Get timestamp ts := atomic.LoadUint64(×tamp) - if entry.Exp == 0 { - entry.Exp = ts + expiration - } else if ts >= entry.Exp { - entry.Hits = 0 - entry.Exp = ts + expiration + + // Set expiration if entry does not exist + if entry.exp == 0 { + entry.exp = ts + expiration + + } else if ts >= entry.exp { + // Check if entry is expired + entry.hits = 0 + entry.exp = ts + expiration } // Increment hits - entry.Hits++ + entry.hits++ - // Check if we need to use the default in-memory storage + // Use default memory storage if cfg.defaultStore { - data[key] = entry - } else { - // Encode Entry to bytes using msgp + entries[key] = entry + } else { // Use custom storage data, err := entry.MarshalMsg(nil) if err != nil { return err @@ -195,13 +198,11 @@ func New(config ...Config) fiber.Handler { } } - mux.Unlock() - // Calculate when it resets in seconds - expire := entry.Exp - ts + expire := entry.exp - ts // Set how many hits we have left - remaining := cfg.Max - entry.Hits + remaining := cfg.Max - entry.hits // Check if hits exceed the cfg.Max if remaining < 0 { diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index 7f5f8b54..a26a493e 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -116,7 +116,7 @@ func Test_Limiter_Concurrency(t *testing.T) { } -// go test -v -run=^$ -bench=Benchmark_Limiter_Benchmark -benchmem -count=4 +// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4 func Benchmark_Limiter(b *testing.B) { app := fiber.New() diff --git a/middleware/limiter/store.go b/middleware/limiter/store.go index 1466a8bf..65e24d9a 100644 --- a/middleware/limiter/store.go +++ b/middleware/limiter/store.go @@ -1,7 +1,10 @@ package limiter -//go:generate msgp -o=store_msgp.go -tests=false -file=store.go -type Entry struct { - Hits int - Exp uint64 +// go:generate msgp +// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported +// don't forget to replace the msgp import path to: +// "github.com/gofiber/fiber/v2/internal/msgp" +type entry struct { + hits int + exp uint64 } diff --git a/middleware/limiter/store_msgp.go b/middleware/limiter/store_msgp.go index d29b22a9..15d7d592 100644 --- a/middleware/limiter/store_msgp.go +++ b/middleware/limiter/store_msgp.go @@ -7,7 +7,7 @@ import ( ) // DecodeMsg implements msgp.Decodable -func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) { +func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte _ = field var zb0001 uint32 @@ -24,16 +24,16 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) { return } switch msgp.UnsafeString(field) { - case "Hits": - z.Hits, err = dc.ReadInt() + case "hits": + z.hits, err = dc.ReadInt() if err != nil { - err = msgp.WrapError(err, "Hits") + err = msgp.WrapError(err, "hits") return } - case "Exp": - z.Exp, err = dc.ReadUint64() + case "exp": + z.exp, err = dc.ReadUint64() if err != nil { - err = msgp.WrapError(err, "Exp") + err = msgp.WrapError(err, "exp") return } default: @@ -48,46 +48,46 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) { } // EncodeMsg implements msgp.Encodable -func (z Entry) EncodeMsg(en *msgp.Writer) (err error) { +func (z entry) EncodeMsg(en *msgp.Writer) (err error) { // map header, size 2 - // write "Hits" - err = en.Append(0x82, 0xa4, 0x48, 0x69, 0x74, 0x73) + // write "hits" + err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73) if err != nil { return } - err = en.WriteInt(z.Hits) + err = en.WriteInt(z.hits) if err != nil { - err = msgp.WrapError(err, "Hits") + err = msgp.WrapError(err, "hits") return } - // write "Exp" - err = en.Append(0xa3, 0x45, 0x78, 0x70) + // write "exp" + err = en.Append(0xa3, 0x65, 0x78, 0x70) if err != nil { return } - err = en.WriteUint64(z.Exp) + err = en.WriteUint64(z.exp) if err != nil { - err = msgp.WrapError(err, "Exp") + err = msgp.WrapError(err, "exp") return } return } // MarshalMsg implements msgp.Marshaler -func (z Entry) MarshalMsg(b []byte) (o []byte, err error) { +func (z entry) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) // map header, size 2 - // string "Hits" - o = append(o, 0x82, 0xa4, 0x48, 0x69, 0x74, 0x73) - o = msgp.AppendInt(o, z.Hits) - // string "Exp" - o = append(o, 0xa3, 0x45, 0x78, 0x70) - o = msgp.AppendUint64(o, z.Exp) + // string "hits" + o = append(o, 0x82, 0xa4, 0x68, 0x69, 0x74, 0x73) + o = msgp.AppendInt(o, z.hits) + // string "exp" + o = append(o, 0xa3, 0x65, 0x78, 0x70) + o = msgp.AppendUint64(o, z.exp) return } // UnmarshalMsg implements msgp.Unmarshaler -func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) { +func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field var zb0001 uint32 @@ -104,16 +104,16 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) { return } switch msgp.UnsafeString(field) { - case "Hits": - z.Hits, bts, err = msgp.ReadIntBytes(bts) + case "hits": + z.hits, bts, err = msgp.ReadIntBytes(bts) if err != nil { - err = msgp.WrapError(err, "Hits") + err = msgp.WrapError(err, "hits") return } - case "Exp": - z.Exp, bts, err = msgp.ReadUint64Bytes(bts) + case "exp": + z.exp, bts, err = msgp.ReadUint64Bytes(bts) if err != nil { - err = msgp.WrapError(err, "Exp") + err = msgp.WrapError(err, "exp") return } default: @@ -129,7 +129,7 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) { } // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z Entry) Msgsize() (s int) { +func (z entry) Msgsize() (s int) { s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size return }