💼 implement Storage

This commit is contained in:
Fenny 2020-10-28 02:29:47 +01:00
parent 32fdbf0ddf
commit ecdda95e15
8 changed files with 459 additions and 106 deletions

View File

@ -5,6 +5,7 @@ package cache
import (
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
@ -26,6 +27,15 @@ type Config struct {
//
// Optional. Default: false
CacheControl bool
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
}
// ConfigDefault is the default config
@ -33,6 +43,7 @@ var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
defaultStore: true,
}
// cache is the manager to store the cached responses
@ -42,14 +53,6 @@ type cache struct {
expiration int64
}
// entry defines the cached response
type entry struct {
body []byte
contentType []byte
statusCode int
expiration int64
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
@ -66,8 +69,29 @@ func New(config ...Config) fiber.Handler {
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.Store == nil {
cfg.defaultStore = true
}
}
var (
// Cache settings
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
mux = &sync.RWMutex{}
// Default store logic (if no Store is provided)
entries = make(map[string]entry)
)
// Update timestamp every second
go func() {
for {
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(1 * time.Second)
}
}()
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
@ -75,23 +99,18 @@ func New(config ...Config) fiber.Handler {
}
}
// Initialize db
db := &cache{
entries: make(map[string]entry),
expiration: int64(cfg.Expiration.Seconds()),
}
// Remove expired entries
go func() {
for {
// GC the entries every 10 seconds to avoid
// GC the entries every 10 seconds
time.Sleep(10 * time.Second)
db.Lock()
for k := range db.entries {
if time.Now().Unix() >= db.entries[k].expiration {
delete(db.entries, k)
mux.Lock()
for k := range entries {
if atomic.LoadUint64(&timestamp) >= entries[k].exp {
delete(entries, k)
}
}
db.Unlock()
mux.Unlock()
}
}()
@ -110,28 +129,65 @@ func New(config ...Config) fiber.Handler {
// Get key from request
key := c.Path()
// Find cached entry
db.RLock()
resp, ok := db.entries[key]
db.RUnlock()
if ok {
// Check if entry is expired
if time.Now().Unix() >= resp.expiration {
db.Lock()
delete(db.entries, key)
db.Unlock()
} else {
// Set response headers from cache
c.Response().SetBodyRaw(resp.body)
c.Response().SetStatusCode(resp.statusCode)
c.Response().Header.SetContentTypeBytes(resp.contentType)
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatInt(resp.expiration-time.Now().Unix(), 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
return nil
// Create new entry
entry := entry{}
// Lock entry
mux.Lock()
defer mux.Unlock()
// Check if we need to use the default in-memory storage
if cfg.defaultStore {
entry = entries[key]
} else {
// Load data from store
storeEntry, err := cfg.Store.Get(key)
if err != nil {
return err
}
// Only decode if we found an entry
if len(storeEntry) > 0 {
// Decode bytes using msgp
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
return err
}
}
}
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Set expiration if entry does not exist
if entry.exp == 0 {
entry.exp = ts + expiration
} else if ts >= entry.exp {
// Check if entry is expired
// Use default memory storage
if cfg.defaultStore {
delete(entries, key)
} else { // Use custom storage
if err := cfg.Store.Delete(key); err != nil {
return err
}
}
} else {
// Set response headers from cache
c.Response().SetBodyRaw(entry.body)
c.Response().SetStatusCode(entry.status)
c.Response().Header.SetContentTypeBytes(entry.cType)
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(entry.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
// Return response
return nil
}
// Continue stack, return err to Fiber if exist
@ -140,14 +196,26 @@ func New(config ...Config) fiber.Handler {
}
// Cache response
db.Lock()
db.entries[key] = entry{
body: c.Response().Body(),
statusCode: c.Response().StatusCode(),
contentType: c.Response().Header.ContentType(),
expiration: time.Now().Unix() + db.expiration,
entry.body = c.Response().Body()
entry.status = c.Response().StatusCode()
entry.cType = c.Response().Header.ContentType()
// Use default memory storage
if cfg.defaultStore {
entries[key] = entry
} else {
// Use custom storage
data, err := entry.MarshalMsg(nil)
if err != nil {
return err
}
// Pass bytes to Storage
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
return err
}
}
db.Unlock()
// Finish response
return nil

View File

@ -6,7 +6,9 @@ import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@ -91,6 +93,55 @@ func Test_Cache(t *testing.T) {
utils.AssertEqual(t, cachedBody, body)
}
// go test -run Test_Cache_Concurrency_Store -race -v
func Test_Cache_Concurrency_Store(t *testing.T) {
// Test concurrency using a custom store
app := fiber.New()
app.Use(New(Config{
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
req := httptest.NewRequest("GET", "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cachedBody, body)
}
func Test_Cache_Invalid_Expiration(t *testing.T) {
app := fiber.New()
cache := New(Config{Expiration: 0 * time.Second})
@ -208,3 +259,36 @@ func Benchmark_Cache(b *testing.B) {
utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
}
// testStore is used for testing custom stores
type testStore struct {
stmap map[string][]byte
mutex *sync.Mutex
}
func (s testStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return []byte{}, nil
} else {
return val, nil
}
}
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s testStore) Clear() error {
return nil
}
func (s testStore) Delete(id string) error {
return nil
}

12
middleware/cache/store.go vendored Normal file
View File

@ -0,0 +1,12 @@
package cache
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
body []byte `msg:"body"`
cType []byte `msg:"cType"`
status int `msg:"status"`
exp uint64 `msg:"exp"`
}

185
middleware/cache/store_msgp.go vendored Normal file
View File

@ -0,0 +1,185 @@
package cache
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "body":
z.body, err = dc.ReadBytes(z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, err = dc.ReadBytes(z.cType)
if err != nil {
err = msgp.WrapError(err, "cType")
return
}
case "status":
z.status, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "status")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 4
// write "body"
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
if err != nil {
return
}
err = en.WriteBytes(z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
// write "cType"
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.cType)
if err != nil {
err = msgp.WrapError(err, "cType")
return
}
// write "status"
err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.status)
if err != nil {
err = msgp.WrapError(err, "status")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 4
// string "body"
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = msgp.AppendBytes(o, z.body)
// string "cType"
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.cType)
// string "status"
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
o = msgp.AppendInt(o, z.status)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "body":
z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
if err != nil {
err = msgp.WrapError(err, "cType")
return
}
case "status":
z.status, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "status")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *entry) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

View File

@ -119,10 +119,10 @@ func New(config ...Config) fiber.Handler {
max = strconv.Itoa(cfg.Max)
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
mux = &sync.RWMutex{}
// Default store logic (if no Store is provided)
data = make(map[string]Entry)
mux = &sync.RWMutex{}
entries = make(map[string]entry)
)
// Update timestamp every second
@ -140,20 +140,20 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
// Get key (default is the remote IP)
// Get key from request
key := cfg.Key(c)
// Create new entry
entry := Entry{}
entry := entry{}
// Lock entry
mux.Lock()
defer mux.Unlock()
// Check if we need to use the default in-memory storage
// Use default memory storage
if cfg.defaultStore {
entry = data[key]
} else {
// Load data from store
entry = entries[key]
} else { // Use custom storage
storeEntry, err := cfg.Store.Get(key)
if err != nil {
return err
@ -167,23 +167,26 @@ func New(config ...Config) fiber.Handler {
}
}
// Set unix timestamp if not exist
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
if entry.Exp == 0 {
entry.Exp = ts + expiration
} else if ts >= entry.Exp {
entry.Hits = 0
entry.Exp = ts + expiration
// Set expiration if entry does not exist
if entry.exp == 0 {
entry.exp = ts + expiration
} else if ts >= entry.exp {
// Check if entry is expired
entry.hits = 0
entry.exp = ts + expiration
}
// Increment hits
entry.Hits++
entry.hits++
// Check if we need to use the default in-memory storage
// Use default memory storage
if cfg.defaultStore {
data[key] = entry
} else {
// Encode Entry to bytes using msgp
entries[key] = entry
} else { // Use custom storage
data, err := entry.MarshalMsg(nil)
if err != nil {
return err
@ -195,13 +198,11 @@ func New(config ...Config) fiber.Handler {
}
}
mux.Unlock()
// Calculate when it resets in seconds
expire := entry.Exp - ts
expire := entry.exp - ts
// Set how many hits we have left
remaining := cfg.Max - entry.Hits
remaining := cfg.Max - entry.hits
// Check if hits exceed the cfg.Max
if remaining < 0 {

View File

@ -116,7 +116,7 @@ func Test_Limiter_Concurrency(t *testing.T) {
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Benchmark -benchmem -count=4
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()

View File

@ -1,7 +1,10 @@
package limiter
//go:generate msgp -o=store_msgp.go -tests=false -file=store.go
type Entry struct {
Hits int
Exp uint64
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
hits int
exp uint64
}

View File

@ -7,7 +7,7 @@ import (
)
// DecodeMsg implements msgp.Decodable
func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
@ -24,16 +24,16 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
return
}
switch msgp.UnsafeString(field) {
case "Hits":
z.Hits, err = dc.ReadInt()
case "hits":
z.hits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "Hits")
err = msgp.WrapError(err, "hits")
return
}
case "Exp":
z.Exp, err = dc.ReadUint64()
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "Exp")
err = msgp.WrapError(err, "exp")
return
}
default:
@ -48,46 +48,46 @@ func (z *Entry) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z Entry) EncodeMsg(en *msgp.Writer) (err error) {
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "Hits"
err = en.Append(0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
// write "hits"
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.Hits)
err = en.WriteInt(z.hits)
if err != nil {
err = msgp.WrapError(err, "Hits")
err = msgp.WrapError(err, "hits")
return
}
// write "Exp"
err = en.Append(0xa3, 0x45, 0x78, 0x70)
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.Exp)
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "Exp")
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z Entry) MarshalMsg(b []byte) (o []byte, err error) {
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "Hits"
o = append(o, 0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.Hits)
// string "Exp"
o = append(o, 0xa3, 0x45, 0x78, 0x70)
o = msgp.AppendUint64(o, z.Exp)
// string "hits"
o = append(o, 0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.hits)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
@ -104,16 +104,16 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
return
}
switch msgp.UnsafeString(field) {
case "Hits":
z.Hits, bts, err = msgp.ReadIntBytes(bts)
case "hits":
z.hits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Hits")
err = msgp.WrapError(err, "hits")
return
}
case "Exp":
z.Exp, bts, err = msgp.ReadUint64Bytes(bts)
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "Exp")
err = msgp.WrapError(err, "exp")
return
}
default:
@ -129,7 +129,7 @@ func (z *Entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z Entry) Msgsize() (s int) {
func (z entry) Msgsize() (s int) {
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
return
}