📦 Switch to fiber.Storage and msgp

pull/900/head
Tom 2020-10-26 21:25:30 +00:00
parent 43e100f36c
commit 66d2e7deda
No known key found for this signature in database
GPG Key ID: D3E7EAA31B39637E
7 changed files with 337 additions and 79 deletions

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v2
go 1.14
require (
github.com/philhofer/fwd v1.1.0
github.com/valyala/fasthttp v1.16.0
golang.org/x/sys v0.0.0-20201020230747-6e5568b54d1a
)

6
go.sum
View File

@ -1,11 +1,9 @@
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/andybalholm/brotli v1.0.1 h1:KqhlKozYbRtJvsPrrEeXcO+N2l6NYT5A2QAFmSULpEc=
github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.1 h1:bPb7nMRdOZYDrpPMTA3EInUQrdgoBinqUuSwlGdKDdE=
github.com/klauspost/compress v1.11.1/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/philhofer/fwd v1.1.0 h1:PAdZw9+/BCf4gc/kA2L/PbGPkFe72Kl2GLZXTG8HpU8=
github.com/philhofer/fwd v1.1.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.16.0 h1:9zAqOYLl8Tuy3E5R6ckzGDJ1g8+pw15oQp2iL9Jl6gQ=

View File

@ -1,8 +1,6 @@
package limiter
import (
"bytes"
"encoding/gob"
"strconv"
"sync"
"sync/atomic"
@ -11,6 +9,9 @@ import (
"github.com/gofiber/fiber/v2"
)
//go:generate msgp -unexported
//msgp:ignore Config
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
@ -45,7 +46,7 @@ type Config struct {
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store Storage
Store fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
@ -152,11 +153,8 @@ func New(config ...Config) fiber.Handler {
// Assume this means item not found.
session = trackedSession{}
} else {
// Decode bytes using gob
var buf bytes.Buffer
_, _ = buf.Write(fromStore)
dec := gob.NewDecoder(&buf)
err := dec.Decode(&session)
// Decode bytes using msgp
_, err := session.UnmarshalMsg(fromStore)
if err != nil {
return err
}
@ -180,15 +178,15 @@ func New(config ...Config) fiber.Handler {
if cfg.usingCustomStore {
// Convert session struct into bytes
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(session)
data, err := session.MarshalMsg(nil)
if err != nil {
return err
}
// Store those bytes
err = cfg.Store.Set(key, buf.Bytes(), cfg.Duration)
err = cfg.Store.Set(key, data, cfg.Duration)
if err != nil {
return err
}

View File

@ -0,0 +1,135 @@
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *trackedSession) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "Hits":
z.Hits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "Hits")
return
}
case "ResetTime":
z.ResetTime, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "ResetTime")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z trackedSession) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "Hits"
err = en.Append(0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.Hits)
if err != nil {
err = msgp.WrapError(err, "Hits")
return
}
// write "ResetTime"
err = en.Append(0xa9, 0x52, 0x65, 0x73, 0x65, 0x74, 0x54, 0x69, 0x6d, 0x65)
if err != nil {
return
}
err = en.WriteUint64(z.ResetTime)
if err != nil {
err = msgp.WrapError(err, "ResetTime")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z trackedSession) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "Hits"
o = append(o, 0x82, 0xa4, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.Hits)
// string "ResetTime"
o = append(o, 0xa9, 0x52, 0x65, 0x73, 0x65, 0x74, 0x54, 0x69, 0x6d, 0x65)
o = msgp.AppendUint64(o, z.ResetTime)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *trackedSession) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "Hits":
z.Hits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Hits")
return
}
case "ResetTime":
z.ResetTime, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "ResetTime")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z trackedSession) Msgsize() (s int) {
s = 1 + 5 + msgp.IntSize + 10 + msgp.Uint64Size
return
}

View File

@ -0,0 +1,123 @@
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"bytes"
"testing"
"github.com/gofiber/fiber/v2/internal/msgp"
)
func TestMarshalUnmarshaltrackedSession(t *testing.T) {
v := trackedSession{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgtrackedSession(b *testing.B) {
v := trackedSession{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgtrackedSession(b *testing.B) {
v := trackedSession{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshaltrackedSession(b *testing.B) {
v := trackedSession{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodetrackedSession(t *testing.T) {
v := trackedSession{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodetrackedSession Msgsize() is inaccurate")
}
vn := trackedSession{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodetrackedSession(b *testing.B) {
v := trackedSession{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodetrackedSession(b *testing.B) {
v := trackedSession{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}

View File

@ -16,6 +16,9 @@ import (
// go test -run Test_Limiter_Concurrency -race -v
func Test_Limiter_Concurrency(t *testing.T) {
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
@ -60,12 +63,14 @@ func Test_Limiter_Concurrency(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
// Test concurrency using a custom store
app = fiber.New()
app.Use(New(Config{
Max: 50,
Duration: 2 * time.Second,
Store: defaultStore{stmap: map[string][]byte{}},
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
}))
app.Get("/", func(c *fiber.Ctx) error {
@ -117,6 +122,33 @@ func Benchmark_Limiter(b *testing.B) {
}
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Duration: 60 * time.Second,
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -run Test_Limiter_Next
func Test_Limiter_Next(t *testing.T) {
app := fiber.New()
@ -156,3 +188,36 @@ func Test_Limiter_Headers(t *testing.T) {
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
}
}
// testStore is used for testing custom stores
type testStore struct {
stmap map[string][]byte
mutex *sync.Mutex
}
func (s testStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return []byte{}, nil
} else {
return val, nil
}
}
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s testStore) Clear() error {
return nil
}
func (s testStore) Delete(id string) error {
return nil
}

View File

@ -1,62 +0,0 @@
package limiter
import (
"sync"
"time"
)
// Storage interface implemented by providers
type Storage interface {
// Get session value. If the ID is not found, this function should return
// []byte{}, nil and not an error.
Get(id string) ([]byte, error)
// Set session value. `exp` will be zero for no duration.
Set(id string, value []byte, exp time.Duration) error
// Delete session value
Delete(id string) error
// Clear clears the store
Clear() error
}
type defaultStore struct {
stmap map[string][]byte
mutex sync.Mutex
}
func (s defaultStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return []byte{}, nil
} else {
return val, nil
}
}
func (s defaultStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s defaultStore) Clear() error {
s.mutex.Lock()
s.stmap = map[string][]byte{}
s.mutex.Unlock()
return nil
}
func (s defaultStore) Delete(id string) error {
s.mutex.Lock()
_, ok := s.stmap[id]
if ok {
delete(s.stmap, id)
}
s.mutex.Unlock()
return nil
}