🩹 fix manager logic

pull/1025/head
Fenny 2020-11-23 07:38:42 +01:00
parent 8fe458011d
commit 323d9d89cc
17 changed files with 547 additions and 315 deletions

3
go.mod
View File

@ -3,7 +3,6 @@ module github.com/gofiber/fiber/v2
go 1.14
require (
github.com/klauspost/compress v1.11.0 // indirect
github.com/valyala/fasthttp v1.17.0
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68
)

12
go.sum
View File

@ -1,9 +1,11 @@
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/andybalholm/brotli v1.0.1 h1:KqhlKozYbRtJvsPrrEeXcO+N2l6NYT5A2QAFmSULpEc=
github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.3 h1:dB4Bn0tN3wdCzQxnS8r06kV74qN/TAfaIS0bVE8h3jc=
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=
@ -14,13 +16,15 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1 h1:a/mKvvZr9Jcc8oKfcmgzyp7OwF73JPWsQLvH1z2Kxck=
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -1,4 +1,4 @@
package mapstore
package memory
import (
"sync"

View File

@ -1,4 +1,4 @@
package mapstore
package memory
import (
"testing"
@ -7,9 +7,9 @@ import (
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_MapStore -v -race
// go test -run Test_Memory -v -race
func Test_MapStore(t *testing.T) {
func Test_Memory(t *testing.T) {
var store = New()
var (
key = "john"
@ -51,8 +51,8 @@ func Test_MapStore(t *testing.T) {
utils.AssertEqual(t, nil, result)
}
// go test -v -run=^$ -bench=Benchmark_MapStore -benchmem -count=4
func Benchmark_MapStore(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
func Benchmark_Memory(b *testing.B) {
keyLength := 1000
keys := make([]string, keyLength)
for i := 0; i < keyLength; i++ {

View File

@ -4,6 +4,7 @@ package cache
import (
"strconv"
"sync"
"sync/atomic"
"time"
@ -25,18 +26,18 @@ func New(config ...Config) fiber.Handler {
var (
// Cache settings
mux = &sync.RWMutex{}
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
)
// create storage handler
store := newStorage(&cfg)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
go func() {
for {
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(750 * time.Millisecond)
time.Sleep(1 * time.Second)
}
}()
@ -55,24 +56,37 @@ func New(config ...Config) fiber.Handler {
// Get key from request
key := cfg.KeyGenerator(c)
// Get/Create new entry
e := store.get(key)
if e == nil {
e = &entry{}
}
// Get entry from pool
e := manager.get(key)
// Lock entry and unlock when finished
mux.Lock()
defer mux.Unlock()
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Set expiration if entry does not exist
if e.exp == 0 {
// Set expiration if entry does not exist
e.exp = ts + expiration
} else if ts >= e.exp {
// Check if entry is expired
store.delete(key)
manager.delete(key)
// External storage saves body data with different key
if cfg.Storage != nil {
manager.delete(key + "_body")
}
} else {
// Seperate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}
// Set response headers from cache
c.Response().Header.SetContentTypeBytes(e.cType)
c.Status(e.status).Send(e.body)
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
// Set Cache-Control header if enabled
if cfg.CacheControl {
@ -90,11 +104,21 @@ func New(config ...Config) fiber.Handler {
}
// Cache response
e.body = utils.SafeBytes(c.Response().Body())
e.status = c.Response().StatusCode()
e.body = utils.CopyBytes(c.Response().Body())
e.cType = utils.CopyBytes(c.Response().Header.ContentType())
e.ctype = utils.SafeBytes(c.Response().Header.ContentType())
store.set(key, e)
// For external Storage we store raw body seperated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, cfg.Expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, cfg.Expiration)
manager.release(e)
} else {
// Store entry in memory
manager.set(key, e, cfg.Expiration)
}
// Finish response
return nil

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -282,7 +283,7 @@ func Benchmark_Cache_Storage(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
//// Store: memory.New(),
Storage: memory.New(),
}))
app.Get("/demo", func(c *fiber.Ctx) error {

121
middleware/cache/manager.go vendored Normal file
View File

@ -0,0 +1,121 @@
package cache
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
body []byte
ctype []byte
status int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
e.body = nil
e.ctype = nil
e.status = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -7,7 +7,7 @@ import (
)
// DecodeMsg implements msgp.Decodable
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
@ -30,10 +30,10 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, err = dc.ReadBytes(z.cType)
case "ctype":
z.ctype, err = dc.ReadBytes(z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
case "status":
@ -60,7 +60,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 4
// write "body"
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
@ -72,14 +72,14 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
err = msgp.WrapError(err, "body")
return
}
// write "cType"
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
// write "ctype"
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.cType)
err = en.WriteBytes(z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
// write "status"
@ -106,15 +106,15 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
}
// MarshalMsg implements msgp.Marshaler
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 4
// string "body"
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = msgp.AppendBytes(o, z.body)
// string "cType"
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.cType)
// string "ctype"
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.ctype)
// string "status"
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
o = msgp.AppendInt(o, z.status)
@ -125,7 +125,7 @@ func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
@ -148,10 +148,10 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
case "ctype":
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
case "status":
@ -179,7 +179,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *entry) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
func (z *item) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

View File

@ -1,80 +0,0 @@
package cache
import (
"github.com/gofiber/fiber/v2/internal/mapstore"
)
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
body []byte `msg:"body"`
cType []byte `msg:"cType"`
status int `msg:"status"`
exp uint64 `msg:"exp"`
}
//msgp:ignore storage
type storage struct {
cfg *Config
store *mapstore.MapStore
}
func newStorage(cfg *Config) *storage {
store := &storage{
cfg: cfg,
}
if cfg.Storage == nil {
store.store = mapstore.New()
}
return store
}
func (s *storage) get(key string) *entry {
if s.cfg.Storage != nil {
raw, err := s.cfg.Storage.Get(key)
if err != nil || raw == nil {
return nil
}
e := &entry{}
if _, err := e.UnmarshalMsg(raw); err != nil {
return nil
}
body, err := s.cfg.Storage.Get(key + "_body")
if err != nil || body == nil {
return nil
}
e.body = body
return e
} else {
val := s.store.Get(key)
if val != nil {
return val.(*entry)
}
}
return nil
}
func (s *storage) set(key string, e *entry) {
if s.cfg.Storage != nil {
// seperate body since we dont want to encode big payloads
body := e.body
e.body = nil
if data, err := e.MarshalMsg(nil); err == nil {
_ = s.cfg.Storage.Set(key, data, s.cfg.Expiration)
_ = s.cfg.Storage.Set(key+"_body", body, s.cfg.Expiration)
}
} else {
s.store.Set(key, e, s.cfg.Expiration)
}
}
func (s *storage) delete(key string) {
if s.cfg.Storage != nil {
_ = s.cfg.Storage.Delete(key)
_ = s.cfg.Storage.Delete(key + "_body")
} else {
s.store.Delete(key)
}
}

View File

@ -8,7 +8,6 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
)
// New creates a new middleware handler
@ -16,10 +15,8 @@ func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Set default values
if cfg.Storage == nil {
cfg.Storage = memory.New()
}
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
@ -45,6 +42,8 @@ func New(config ...Config) fiber.Handler {
extractor = csrfFromCookie(selectors[1])
}
dummyValue := []byte{'+'}
// Return new handler
return func(c *fiber.Ctx) (err error) {
// Don't execute middleware if Next returns true
@ -52,12 +51,6 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
// create storage handler
store := &storage{
cfg: &cfg,
entries: make(map[string][]byte),
}
var token string
// Action depends on the HTTP method
@ -72,7 +65,7 @@ func New(config ...Config) fiber.Handler {
token = cfg.KeyGenerator()
// Add token to Storage
store.set(token)
manager.setRaw(token, dummyValue, cfg.Expiration)
}
// Create cookie to pass token to client
@ -86,20 +79,19 @@ func New(config ...Config) fiber.Handler {
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
}
// Set cookie to response
c.Cookie(cookie)
case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut:
// Verify CSRF token
// Extract token from client request i.e. header, query, param, form or cookie
token, err = extractor(c)
if err != nil {
return fiber.ErrForbidden
}
// We have a problem extracting the csrf token from Storage
if store.get(token) {
// The token is invalid, let client generate a new one
store.delete(token)
// 403 if token does not exist in Storage
if manager.getRaw(token) == nil {
// Expire cookie
c.Cookie(&fiber.Cookie{
Name: cfg.CookieName,
@ -110,8 +102,13 @@ func New(config ...Config) fiber.Handler {
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
})
// Return 403 Forbidden
return fiber.ErrForbidden
}
// The token is validated, time to delete it
manager.delete(token)
}
// Protect clients from caching the response by telling the browser

112
middleware/csrf/manager.go Normal file
View File

@ -0,0 +1,112 @@
package csrf
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -0,0 +1,90 @@
package csrf
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 0
err = en.Append(0x80)
if err != nil {
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 0
o = append(o, 0x80)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1
return
}

View File

@ -1,51 +0,0 @@
package csrf
import (
"sync"
)
// We only use Keys in Storage, so we need a dummy value
var emptyByte = []byte{'+'}
type storage struct {
cfg *Config
mux *sync.RWMutex
entries map[string][]byte
}
func (s *storage) get(key string) bool {
if s.cfg.Storage != nil {
val, err := s.cfg.Storage.Get(key)
if err == nil && val != nil {
return true
}
} else {
s.mux.Lock()
_, ok := s.entries[key]
s.mux.Unlock()
if ok {
return true
}
}
return false
}
func (s *storage) set(key string) {
if s.cfg.Storage != nil {
_ = s.cfg.Storage.Set(key, emptyByte, s.cfg.Expiration)
} else {
s.mux.Lock()
s.entries[key] = emptyByte
s.mux.Unlock()
}
}
func (s *storage) delete(key string) {
if s.cfg.Storage != nil {
_ = s.cfg.Storage.Delete(key)
} else {
s.mux.Lock()
delete(s.entries, key)
s.mux.Unlock()
}
}

View File

@ -2,6 +2,7 @@ package limiter
import (
"strconv"
"sync"
"sync/atomic"
"time"
@ -24,17 +25,15 @@ func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
var (
// Limiter settings
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
// mux = &sync.RWMutex{}
// // Default store logic (if no Store is provided)
// entries = make(map[string]entry)
)
store := newStorage(&cfg)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
go func() {
@ -54,28 +53,12 @@ func New(config ...Config) fiber.Handler {
// Get key from request
key := cfg.KeyGenerator(c)
e := store.get(key)
// // Create new entry
// entry := entry{}
// Get entry from pool and release when finished
e := manager.get(key)
// // Lock entry
// mux.Lock()
// defer mux.Unlock()
// // Use Storage if provided
// if cfg.Storage != nil {
// val, err := cfg.Storage.Get(key)
// if val != nil && len(val) > 0 {
// if _, err := entry.UnmarshalMsg(val); err != nil {
// return err
// }
// }
// if err != nil && err.Error() != errNotExist {
// fmt.Println("[LIMITER]", err.Error())
// }
// } else {
// entry = entries[key]
// }
// Lock entry and unlock when finished
mux.Lock()
defer mux.Unlock()
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
@ -93,29 +76,15 @@ func New(config ...Config) fiber.Handler {
// Increment hits
e.hits++
store.set(key, e)
// // Use Storage if provided
// if cfg.Storage != nil {
// // Marshal entry to bytes
// val, err := entry.MarshalMsg(nil)
// if err != nil {
// return err
// }
// // Pass value to Storage
// if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
// return err
// }
// } else {
// entries[key] = entry
// }
// Calculate when it resets in seconds
expire := e.exp - ts
// Set how many hits we have left
remaining := cfg.Max - e.hits
// Update storage
manager.set(key, e, cfg.Expiration)
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header

View File

@ -0,0 +1,115 @@
package limiter
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
hits int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
e.hits = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -7,7 +7,7 @@ import (
)
// DecodeMsg implements msgp.Decodable
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
@ -48,7 +48,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "hits"
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
@ -74,7 +74,7 @@ func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
}
// MarshalMsg implements msgp.Marshaler
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "hits"
@ -87,7 +87,7 @@ func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
@ -129,7 +129,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z entry) Msgsize() (s int) {
func (z item) Msgsize() (s int) {
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

View File

@ -1,69 +0,0 @@
package limiter
import (
"github.com/gofiber/fiber/v2/internal/mapstore"
)
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
hits int `msg:"hits"`
exp uint64 `msg:"exp"`
}
//msgp:ignore storage
type storage struct {
cfg *Config
store *mapstore.Storage
}
func newStorage(cfg *Config) *storage {
store := &storage{
cfg: cfg,
}
if cfg.Storage == nil {
store.store = mapstore.New()
}
return store
}
func (s *storage) get(key string) (e entry) {
if s.cfg.Storage != nil {
raw, err := s.cfg.Storage.Get(key)
if err != nil || raw == nil {
return
}
if _, err := e.UnmarshalMsg(raw); err != nil {
return
}
return
} else {
// val := s.mem.Get(key).(*entry)
var ok bool
e, ok = s.store.Get(key).(entry)
if !ok {
return
}
}
return
}
func (s *storage) set(key string, e entry) {
if s.cfg.Storage != nil {
if data, err := e.MarshalMsg(nil); err == nil {
_ = s.cfg.Storage.Set(key, data, s.cfg.Expiration)
}
} else {
s.store.Set(key, e, s.cfg.Expiration)
}
}
func (s *storage) delete(key string) {
if s.cfg.Storage != nil {
_ = s.cfg.Storage.Delete(key)
} else {
s.store.Delete(key)
}
}