mirror of https://github.com/gofiber/fiber.git
📦 Switch to fiber.Storage and msgp
parent
43e100f36c
commit
66d2e7deda
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v2
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/philhofer/fwd v1.1.0
|
||||
github.com/valyala/fasthttp v1.16.0
|
||||
golang.org/x/sys v0.0.0-20201020230747-6e5568b54d1a
|
||||
)
|
||||
|
|
6
go.sum
6
go.sum
|
@ -1,11 +1,9 @@
|
|||
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
|
||||
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/andybalholm/brotli v1.0.1 h1:KqhlKozYbRtJvsPrrEeXcO+N2l6NYT5A2QAFmSULpEc=
|
||||
github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
|
||||
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.1 h1:bPb7nMRdOZYDrpPMTA3EInUQrdgoBinqUuSwlGdKDdE=
|
||||
github.com/klauspost/compress v1.11.1/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/philhofer/fwd v1.1.0 h1:PAdZw9+/BCf4gc/kA2L/PbGPkFe72Kl2GLZXTG8HpU8=
|
||||
github.com/philhofer/fwd v1.1.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.16.0 h1:9zAqOYLl8Tuy3E5R6ckzGDJ1g8+pw15oQp2iL9Jl6gQ=
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -11,6 +9,9 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
//go:generate msgp -unexported
|
||||
//msgp:ignore Config
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
|
@ -45,7 +46,7 @@ type Config struct {
|
|||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store Storage
|
||||
Store fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
|
@ -152,11 +153,8 @@ func New(config ...Config) fiber.Handler {
|
|||
// Assume this means item not found.
|
||||
session = trackedSession{}
|
||||
} else {
|
||||
// Decode bytes using gob
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(fromStore)
|
||||
dec := gob.NewDecoder(&buf)
|
||||
err := dec.Decode(&session)
|
||||
// Decode bytes using msgp
|
||||
_, err := session.UnmarshalMsg(fromStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -180,15 +178,15 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
if cfg.usingCustomStore {
|
||||
// Convert session struct into bytes
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(session)
|
||||
|
||||
data, err := session.MarshalMsg(nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store those bytes
|
||||
err = cfg.Store.Set(key, buf.Bytes(), cfg.Duration)
|
||||
err = cfg.Store.Set(key, data, cfg.Duration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *trackedSession) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "Hits":
|
||||
z.Hits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Hits")
|
||||
return
|
||||
}
|
||||
case "ResetTime":
|
||||
z.ResetTime, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "ResetTime")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z trackedSession) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// write "Hits"
|
||||
err = en.Append(0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.Hits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Hits")
|
||||
return
|
||||
}
|
||||
// write "ResetTime"
|
||||
err = en.Append(0xa9, 0x52, 0x65, 0x73, 0x65, 0x74, 0x54, 0x69, 0x6d, 0x65)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteUint64(z.ResetTime)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "ResetTime")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z trackedSession) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// string "Hits"
|
||||
o = append(o, 0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.Hits)
|
||||
// string "ResetTime"
|
||||
o = append(o, 0xa9, 0x52, 0x65, 0x73, 0x65, 0x74, 0x54, 0x69, 0x6d, 0x65)
|
||||
o = msgp.AppendUint64(o, z.ResetTime)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *trackedSession) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "Hits":
|
||||
z.Hits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Hits")
|
||||
return
|
||||
}
|
||||
case "ResetTime":
|
||||
z.ResetTime, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "ResetTime")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z trackedSession) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.IntSize + 10 + msgp.Uint64Size
|
||||
return
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
func TestMarshalUnmarshaltrackedSession(t *testing.T) {
|
||||
v := trackedSession{}
|
||||
bts, err := v.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
left, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
|
||||
}
|
||||
|
||||
left, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalMsgtrackedSession(b *testing.B) {
|
||||
v := trackedSession{}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v.MarshalMsg(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendMsgtrackedSession(b *testing.B) {
|
||||
v := trackedSession{}
|
||||
bts := make([]byte, 0, v.Msgsize())
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshaltrackedSession(b *testing.B) {
|
||||
v := trackedSession{}
|
||||
bts, _ := v.MarshalMsg(nil)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodetrackedSession(t *testing.T) {
|
||||
v := trackedSession{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
|
||||
m := v.Msgsize()
|
||||
if buf.Len() > m {
|
||||
t.Log("WARNING: TestEncodeDecodetrackedSession Msgsize() is inaccurate")
|
||||
}
|
||||
|
||||
vn := trackedSession{}
|
||||
err := msgp.Decode(&buf, &vn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
msgp.Encode(&buf, &v)
|
||||
err = msgp.NewReader(&buf).Skip()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodetrackedSession(b *testing.B) {
|
||||
v := trackedSession{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
b.SetBytes(int64(buf.Len()))
|
||||
en := msgp.NewWriter(msgp.Nowhere)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v.EncodeMsg(en)
|
||||
}
|
||||
en.Flush()
|
||||
}
|
||||
|
||||
func BenchmarkDecodetrackedSession(b *testing.B) {
|
||||
v := trackedSession{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
b.SetBytes(int64(buf.Len()))
|
||||
rd := msgp.NewEndlessReader(buf.Bytes(), b)
|
||||
dc := msgp.NewReader(rd)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := v.DecodeMsg(dc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,6 +16,9 @@ import (
|
|||
|
||||
// go test -run Test_Limiter_Concurrency -race -v
|
||||
func Test_Limiter_Concurrency(t *testing.T) {
|
||||
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
|
@ -60,12 +63,14 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Test concurrency using a custom store
|
||||
|
||||
app = fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Duration: 2 * time.Second,
|
||||
Store: defaultStore{stmap: map[string][]byte{}},
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
|
@ -117,6 +122,33 @@ func Benchmark_Limiter(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
||||
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Duration: 60 * time.Second,
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Next
|
||||
func Test_Limiter_Next(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
@ -156,3 +188,36 @@ func Test_Limiter_Headers(t *testing.T) {
|
|||
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
|
||||
}
|
||||
}
|
||||
|
||||
// testStore is used for testing custom stores
|
||||
type testStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex *sync.Mutex
|
||||
}
|
||||
|
||||
func (s testStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.Unlock()
|
||||
if !ok {
|
||||
return []byte{}, nil
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Storage interface implemented by providers
|
||||
type Storage interface {
|
||||
// Get session value. If the ID is not found, this function should return
|
||||
// []byte{}, nil and not an error.
|
||||
Get(id string) ([]byte, error)
|
||||
// Set session value. `exp` will be zero for no duration.
|
||||
Set(id string, value []byte, exp time.Duration) error
|
||||
// Delete session value
|
||||
Delete(id string) error
|
||||
// Clear clears the store
|
||||
Clear() error
|
||||
}
|
||||
|
||||
type defaultStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func (s defaultStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.Unlock()
|
||||
if !ok {
|
||||
return []byte{}, nil
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s defaultStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s defaultStore) Clear() error {
|
||||
s.mutex.Lock()
|
||||
s.stmap = map[string][]byte{}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s defaultStore) Delete(id string) error {
|
||||
s.mutex.Lock()
|
||||
_, ok := s.stmap[id]
|
||||
if ok {
|
||||
delete(s.stmap, id)
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue