mirror of https://github.com/gofiber/fiber.git
commit
2e30da966a
26
app.go
26
app.go
|
@ -39,27 +39,27 @@ type Handler = func(*Ctx) error
|
|||
// Map is a shortcut for map[string]interface{}, useful for JSON returns
|
||||
type Map map[string]interface{}
|
||||
|
||||
// Storage interface that is implemented by storage providers for different
|
||||
// middleware packages like cache, limiter, session and csrf
|
||||
// Storage interface for communicating with different database/key-value
|
||||
// providers
|
||||
type Storage interface {
|
||||
// Get retrieves the value for the given key.
|
||||
// If no value is not found it returns ErrNotExit error
|
||||
// Get gets the value for the given key.
|
||||
// It returns ErrNotFound if the storage does not contain the key.
|
||||
Get(key string) ([]byte, error)
|
||||
|
||||
// Set stores the given value for the given key along with a
|
||||
// time-to-live expiration value, 0 means live for ever
|
||||
// The key must not be "" and the empty values are ignored.
|
||||
// Empty key or value will be ignored without an error.
|
||||
Set(key string, val []byte, ttl time.Duration) error
|
||||
|
||||
// Delete deletes the stored value for the given key.
|
||||
// Deleting a non-existing key-value pair does NOT lead to an error.
|
||||
// The key must not be "".
|
||||
// Delete deletes the value for the given key.
|
||||
// It returns no error if the storage does not contain the key,
|
||||
Delete(key string) error
|
||||
|
||||
// Reset the storage
|
||||
// Reset resets the storage and delete all keys.
|
||||
Reset() error
|
||||
|
||||
// Close the storage
|
||||
// Close closes the storage and will stop any running garbage
|
||||
// collectors and open connections.
|
||||
Close() error
|
||||
}
|
||||
|
||||
|
@ -149,6 +149,8 @@ type Config struct {
|
|||
ETag bool `json:"etag"`
|
||||
|
||||
// Max body size that the server accepts.
|
||||
// -1 will decline any body size
|
||||
//
|
||||
// Default: 4 * 1024 * 1024
|
||||
BodyLimit int `json:"body_limit"`
|
||||
|
||||
|
@ -352,7 +354,7 @@ func New(config ...Config) *App {
|
|||
}
|
||||
|
||||
// Override default values
|
||||
if app.config.BodyLimit <= 0 {
|
||||
if app.config.BodyLimit == 0 {
|
||||
app.config.BodyLimit = DefaultBodyLimit
|
||||
}
|
||||
if app.config.Concurrency <= 0 {
|
||||
|
@ -373,8 +375,10 @@ func New(config ...Config) *App {
|
|||
if app.config.ErrorHandler == nil {
|
||||
app.config.ErrorHandler = DefaultErrorHandler
|
||||
}
|
||||
|
||||
// Init app
|
||||
app.init()
|
||||
|
||||
// Return app
|
||||
return app
|
||||
}
|
||||
|
|
3
go.mod
3
go.mod
|
@ -3,7 +3,6 @@ module github.com/gofiber/fiber/v2
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.11.0 // indirect
|
||||
github.com/valyala/fasthttp v1.17.0
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68
|
||||
)
|
||||
|
|
12
go.sum
12
go.sum
|
@ -1,9 +1,11 @@
|
|||
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
|
||||
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/andybalholm/brotli v1.0.1 h1:KqhlKozYbRtJvsPrrEeXcO+N2l6NYT5A2QAFmSULpEc=
|
||||
github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
|
||||
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
|
||||
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.3 h1:dB4Bn0tN3wdCzQxnS8r06kV74qN/TAfaIS0bVE8h3jc=
|
||||
github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=
|
||||
|
@ -14,13 +16,15 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg=
|
||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1 h1:a/mKvvZr9Jcc8oKfcmgzyp7OwF73JPWsQLvH1z2Kxck=
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
sync.RWMutex
|
||||
data map[string]item // data
|
||||
ts uint64 // timestamp
|
||||
}
|
||||
|
||||
type item struct {
|
||||
v interface{} // val
|
||||
e uint64 // exp
|
||||
}
|
||||
|
||||
func New() *Storage {
|
||||
store := &Storage{
|
||||
data: make(map[string]item),
|
||||
ts: uint64(time.Now().Unix()),
|
||||
}
|
||||
go store.gc(10 * time.Millisecond)
|
||||
go store.updater(1 * time.Second)
|
||||
return store
|
||||
}
|
||||
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) interface{} {
|
||||
s.RLock()
|
||||
v, ok := s.data[key]
|
||||
s.RUnlock()
|
||||
if !ok || v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
|
||||
return nil
|
||||
}
|
||||
return v.v
|
||||
}
|
||||
|
||||
// Set key with value
|
||||
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
|
||||
var exp uint64
|
||||
if ttl > 0 {
|
||||
exp = uint64(ttl.Seconds()) + atomic.LoadUint64(&s.ts)
|
||||
}
|
||||
s.Lock()
|
||||
s.data[key] = item{val, exp}
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Delete key by key
|
||||
func (s *Storage) Delete(key string) {
|
||||
s.Lock()
|
||||
delete(s.data, key)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Reset all keys
|
||||
func (s *Storage) Reset() {
|
||||
s.Lock()
|
||||
s.data = make(map[string]item)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
func (s *Storage) updater(sleep time.Duration) {
|
||||
for {
|
||||
time.Sleep(sleep)
|
||||
atomic.StoreUint64(&s.ts, uint64(time.Now().Unix()))
|
||||
}
|
||||
}
|
||||
func (s *Storage) gc(sleep time.Duration) {
|
||||
expired := []string{}
|
||||
for {
|
||||
time.Sleep(sleep)
|
||||
expired = expired[:0]
|
||||
s.RLock()
|
||||
for key, v := range s.data {
|
||||
if v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
|
||||
expired = append(expired, key)
|
||||
}
|
||||
}
|
||||
s.RUnlock()
|
||||
s.Lock()
|
||||
for i := range expired {
|
||||
delete(s.data, expired[i])
|
||||
}
|
||||
s.Unlock()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_Memory -v -race
|
||||
|
||||
func Test_Memory(t *testing.T) {
|
||||
var store = New()
|
||||
var (
|
||||
key = "john"
|
||||
val interface{} = []byte("doe")
|
||||
exp = 1 * time.Second
|
||||
)
|
||||
|
||||
store.Set(key, val, 0)
|
||||
store.Set(key, val, 0)
|
||||
|
||||
result := store.Get(key)
|
||||
utils.AssertEqual(t, val, result)
|
||||
|
||||
result = store.Get("empty")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set(key, val, exp)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set(key, val, 0)
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, val, result)
|
||||
|
||||
store.Delete(key)
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set("john", val, 0)
|
||||
store.Set("doe", val, 0)
|
||||
store.Reset()
|
||||
|
||||
result = store.Get("john")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
result = store.Get("doe")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
|
||||
func Benchmark_Memory(b *testing.B) {
|
||||
keyLength := 1000
|
||||
keys := make([]string, keyLength)
|
||||
for i := 0; i < keyLength; i++ {
|
||||
keys[i] = utils.UUID()
|
||||
}
|
||||
value := []string{"some", "random", "value"}
|
||||
|
||||
ttl := 2 * time.Second
|
||||
b.Run("fiber_memory", func(b *testing.B) {
|
||||
d := New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
for _, key := range keys {
|
||||
d.Set(key, value, ttl)
|
||||
}
|
||||
for _, key := range keys {
|
||||
_ = d.Get(key)
|
||||
}
|
||||
for _, key := range keys {
|
||||
d.Delete(key)
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
package memory
|
||||
|
||||
import "time"
|
||||
|
||||
// Config defines the config for storage.
|
||||
type Config struct {
|
||||
// Time before deleting expired keys
|
||||
//
|
||||
// Default is 10 * time.Second
|
||||
GCInterval time.Duration
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
GCInterval: 10 * time.Second,
|
||||
}
|
||||
|
||||
// configDefault is a helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.GCInterval.Seconds()) <= 0 {
|
||||
cfg.GCInterval = ConfigDefault.GCInterval
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -23,14 +23,11 @@ type entry struct {
|
|||
}
|
||||
|
||||
// New creates a new memory storage
|
||||
func New(config ...Config) *Storage {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
func New() *Storage {
|
||||
// Create storage
|
||||
store := &Storage{
|
||||
db: make(map[string]entry),
|
||||
gcInterval: cfg.GCInterval,
|
||||
gcInterval: 10 * time.Second,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
|
|
|
@ -61,12 +61,12 @@ type Config struct {
|
|||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store fiber.Storage
|
||||
Storage fiber.Storage
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -77,8 +77,9 @@ var ConfigDefault = Config{
|
|||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
Storage: nil,
|
||||
}
|
||||
```
|
||||
|
|
|
@ -17,15 +17,21 @@ func New(config ...Config) fiber.Handler {
|
|||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Nothing to cache
|
||||
if int(cfg.Expiration.Seconds()) < 0 {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// Cache settings
|
||||
mux = &sync.RWMutex{}
|
||||
timestamp = uint64(time.Now().Unix())
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
mux = &sync.RWMutex{}
|
||||
|
||||
// Default store logic (if no Store is provided)
|
||||
entries = make(map[string]entry)
|
||||
)
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
go func() {
|
||||
|
@ -35,30 +41,6 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
}()
|
||||
|
||||
// Nothing to cache
|
||||
if int(cfg.Expiration.Seconds()) < 0 {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Remove expired entries
|
||||
if cfg.defaultStore {
|
||||
go func() {
|
||||
for {
|
||||
// GC the entries every 10 seconds
|
||||
time.Sleep(10 * time.Second)
|
||||
mux.Lock()
|
||||
for k := range entries {
|
||||
if atomic.LoadUint64(×tamp) >= entries[k].exp {
|
||||
delete(entries, k)
|
||||
}
|
||||
}
|
||||
mux.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
|
@ -72,74 +54,43 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.Key(c)
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Create new entry
|
||||
var entry entry
|
||||
var entryBody []byte
|
||||
// Get entry from pool
|
||||
e := manager.get(key)
|
||||
|
||||
// Lock entry
|
||||
// Lock entry and unlock when finished
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Check if we need to use the default in-memory storage
|
||||
if cfg.defaultStore {
|
||||
entry = entries[key]
|
||||
|
||||
} else {
|
||||
// Load data from store
|
||||
storeEntry, err := cfg.Storage.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only decode if we found an entry
|
||||
if storeEntry != nil {
|
||||
// Decode bytes using msgp
|
||||
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get timestamp
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if entry.exp == 0 {
|
||||
entry.exp = ts + expiration
|
||||
if e.exp == 0 {
|
||||
// Set expiration if entry does not exist
|
||||
e.exp = ts + expiration
|
||||
|
||||
} else if ts >= entry.exp {
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
delete(entries, key)
|
||||
} else { // Use custom storage
|
||||
if err := cfg.Storage.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cfg.Storage.Delete(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
manager.delete(key)
|
||||
// External storage saves body data with different key
|
||||
if cfg.Storage != nil {
|
||||
manager.delete(key + "_body")
|
||||
}
|
||||
|
||||
} else {
|
||||
if cfg.defaultStore {
|
||||
c.Response().SetBodyRaw(entry.body)
|
||||
} else {
|
||||
c.Response().SetBodyRaw(entryBody)
|
||||
// Seperate body value to avoid msgp serialization
|
||||
// We can store raw bytes with Storage 👍
|
||||
if cfg.Storage != nil {
|
||||
e.body = manager.getRaw(key + "_body")
|
||||
}
|
||||
// Set response headers from cache
|
||||
c.Response().SetStatusCode(entry.status)
|
||||
c.Response().Header.SetContentTypeBytes(entry.cType)
|
||||
c.Response().SetBodyRaw(e.body)
|
||||
c.Response().SetStatusCode(e.status)
|
||||
c.Response().Header.SetContentTypeBytes(e.ctype)
|
||||
|
||||
// Set Cache-Control header if enabled
|
||||
if cfg.CacheControl {
|
||||
maxAge := strconv.FormatUint(entry.exp-ts, 10)
|
||||
maxAge := strconv.FormatUint(e.exp-ts, 10)
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
||||
}
|
||||
|
||||
|
@ -153,31 +104,20 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Cache response
|
||||
entryBody = utils.SafeBytes(c.Response().Body())
|
||||
entry.status = c.Response().StatusCode()
|
||||
entry.cType = utils.SafeBytes(c.Response().Header.ContentType())
|
||||
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
entry.body = entryBody
|
||||
entries[key] = entry
|
||||
e.body = utils.SafeBytes(c.Response().Body())
|
||||
e.status = c.Response().StatusCode()
|
||||
e.ctype = utils.SafeBytes(c.Response().Header.ContentType())
|
||||
|
||||
// For external Storage we store raw body seperated
|
||||
if cfg.Storage != nil {
|
||||
manager.setRaw(key+"_body", e.body, cfg.Expiration)
|
||||
// avoid body msgp encoding
|
||||
e.body = nil
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
manager.release(e)
|
||||
} else {
|
||||
// Use custom storage
|
||||
data, err := entry.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
// Store entry in memory
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
}
|
||||
|
||||
// Finish response
|
||||
|
|
|
@ -6,13 +6,12 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -93,54 +92,54 @@ func Test_Cache(t *testing.T) {
|
|||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
|
||||
// go test -run Test_Cache_Concurrency_Store -race -v
|
||||
func Test_Cache_Concurrency_Store(t *testing.T) {
|
||||
// Test concurrency using a custom store
|
||||
// // go test -run Test_Cache_Concurrency_Storage -race -v
|
||||
// func Test_Cache_Concurrency_Storage(t *testing.T) {
|
||||
// // Test concurrency using a custom store
|
||||
|
||||
app := fiber.New()
|
||||
// app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
|
||||
}))
|
||||
// app.Use(New(Config{
|
||||
// Storage: memory.New(),
|
||||
// }))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
// app.Get("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendString("Hello tester!")
|
||||
// })
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
// var wg sync.WaitGroup
|
||||
// singleRequest := func(wg *sync.WaitGroup) {
|
||||
// defer wg.Done()
|
||||
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
// }
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
// for i := 0; i <= 49; i++ {
|
||||
// wg.Add(1)
|
||||
// go singleRequest(&wg)
|
||||
// }
|
||||
|
||||
wg.Wait()
|
||||
// wg.Wait()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// req := httptest.NewRequest("GET", "/", nil)
|
||||
// resp, err := app.Test(req)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
// cachedResp, err := app.Test(cachedReq)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// cachedBody, err := ioutil.ReadAll(cachedResp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
// utils.AssertEqual(t, cachedBody, body)
|
||||
// }
|
||||
|
||||
func Test_Cache_Invalid_Expiration(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
@ -235,7 +234,7 @@ func Test_Cache_NothingToCache(t *testing.T) {
|
|||
func Test_CustomKey(t *testing.T) {
|
||||
app := fiber.New()
|
||||
var called bool
|
||||
app.Use(New(Config{Key: func(c *fiber.Ctx) string {
|
||||
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
|
||||
called = true
|
||||
return c.Path()
|
||||
}}))
|
||||
|
@ -279,12 +278,12 @@ func Benchmark_Cache(b *testing.B) {
|
|||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache_Store -benchmem -count=4
|
||||
func Benchmark_Cache_Store(b *testing.B) {
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
|
||||
func Benchmark_Cache_Storage(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
|
@ -308,43 +307,3 @@ func Benchmark_Cache_Store(b *testing.B) {
|
|||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
// testStore is used for testing custom stores
|
||||
type testStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func (s testStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.RLock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.RUnlock()
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Reset() error {
|
||||
s.stmap = map[string][]byte{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Delete(id string) error {
|
||||
s.mutex.Lock()
|
||||
delete(s.stmap, id)
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -29,19 +29,18 @@ type Config struct {
|
|||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// Deprecated, use Storage instead
|
||||
Store fiber.Storage
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
defaultStore bool
|
||||
// Deprecated, use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated, use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
|
@ -49,10 +48,10 @@ var ConfigDefault = Config{
|
|||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
defaultStore: true,
|
||||
Storage: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
|
@ -66,22 +65,22 @@ func configDefault(config ...Config) Config {
|
|||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("[CACHE] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.Storage == nil && cfg.Store == nil {
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("cache: `Store` is deprecated, use `Storage` instead")
|
||||
cfg.Storage = cfg.Store
|
||||
cfg.defaultStore = true
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
body []byte
|
||||
ctype []byte
|
||||
status int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
e.body = nil
|
||||
e.ctype = nil
|
||||
e.status = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -30,10 +30,10 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
case "cType":
|
||||
z.cType, err = dc.ReadBytes(z.cType)
|
||||
case "ctype":
|
||||
z.ctype, err = dc.ReadBytes(z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
|
@ -60,7 +60,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 4
|
||||
// write "body"
|
||||
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
|
@ -72,14 +72,14 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
|||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
// write "cType"
|
||||
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||
// write "ctype"
|
||||
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteBytes(z.cType)
|
||||
err = en.WriteBytes(z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
// write "status"
|
||||
|
@ -106,15 +106,15 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
|||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 4
|
||||
// string "body"
|
||||
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = msgp.AppendBytes(o, z.body)
|
||||
// string "cType"
|
||||
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||
o = msgp.AppendBytes(o, z.cType)
|
||||
// string "ctype"
|
||||
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
o = msgp.AppendBytes(o, z.ctype)
|
||||
// string "status"
|
||||
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||
o = msgp.AppendInt(o, z.status)
|
||||
|
@ -125,7 +125,7 @@ func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -148,10 +148,10 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
case "cType":
|
||||
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
|
||||
case "ctype":
|
||||
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
|
@ -179,7 +179,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *entry) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
func (z *item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package cache
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type entry struct {
|
||||
body []byte `msg:"body"`
|
||||
cType []byte `msg:"cType"`
|
||||
status int `msg:"status"`
|
||||
exp uint64 `msg:"exp"`
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
# Compress
|
||||
|
||||
Compression middleware for [Fiber](https://github.com/gofiber/fiber) that will compress the response using `gzip`, `deflate` and `brotli` compression depending on the [Accept-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) header.
|
||||
|
||||
- [Signatures](#signatures)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Cross-Origin Resource Sharing (CORS)
|
||||
|
||||
CORS middleware for [Fiber](https://github.com/gofiber/fiber) that that can be used to enable [Cross-Origin Resource Sharing](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) with various options.
|
||||
|
||||
### Table of Contents
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
|
@ -16,10 +15,8 @@ func New(config ...Config) fiber.Handler {
|
|||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Set default values
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = memory.New()
|
||||
}
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Generate the correct extractor to get the token from the correct location
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
|
@ -45,8 +42,7 @@ func New(config ...Config) fiber.Handler {
|
|||
extractor = csrfFromCookie(selectors[1])
|
||||
}
|
||||
|
||||
// We only use Keys in Storage, so we need a dummy value
|
||||
dummyVal := []byte{'+'}
|
||||
dummyValue := []byte{'+'}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
|
@ -69,9 +65,7 @@ func New(config ...Config) fiber.Handler {
|
|||
token = cfg.KeyGenerator()
|
||||
|
||||
// Add token to Storage
|
||||
if err = cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil {
|
||||
fmt.Println("[CSRF]", err.Error())
|
||||
}
|
||||
manager.setRaw(token, dummyValue, cfg.Expiration)
|
||||
}
|
||||
|
||||
// Create cookie to pass token to client
|
||||
|
@ -85,22 +79,19 @@ func New(config ...Config) fiber.Handler {
|
|||
HTTPOnly: cfg.CookieHTTPOnly,
|
||||
SameSite: cfg.CookieSameSite,
|
||||
}
|
||||
|
||||
// Set cookie to response
|
||||
c.Cookie(cookie)
|
||||
|
||||
case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut:
|
||||
// Verify CSRF token
|
||||
// Extract token from client request i.e. header, query, param, form or cookie
|
||||
token, err = extractor(c)
|
||||
if err != nil {
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
// We have a problem extracting the csrf token from Storage
|
||||
if _, err = cfg.Storage.Get(token); err != nil {
|
||||
// The token is invalid, let client generate a new one
|
||||
if err = cfg.Storage.Delete(token); err != nil {
|
||||
fmt.Println("[CSRF]", err.Error())
|
||||
}
|
||||
|
||||
// 403 if token does not exist in Storage
|
||||
if manager.getRaw(token) == nil {
|
||||
|
||||
// Expire cookie
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cfg.CookieName,
|
||||
|
@ -111,8 +102,13 @@ func New(config ...Config) fiber.Handler {
|
|||
HTTPOnly: cfg.CookieHTTPOnly,
|
||||
SameSite: cfg.CookieSameSite,
|
||||
})
|
||||
|
||||
// Return 403 Forbidden
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
// The token is validated, time to delete it
|
||||
manager.delete(token)
|
||||
}
|
||||
|
||||
// Protect clients from caching the response by telling the browser
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package csrf
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 0
|
||||
err = en.Append(0x80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 0
|
||||
o = append(o, 0x80)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1
|
||||
return
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
# ETag
|
||||
|
||||
ETag middleware for [Fiber](https://github.com/gofiber/fiber) that lets caches be more efficient and save bandwidth, as a web server does not need to resend a full response if the content has not changed.
|
||||
|
||||
### Table of Contents
|
||||
|
@ -37,6 +38,11 @@ app.Get("/", func(c *fiber.Ctx) error {
|
|||
```go
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
|
@ -46,18 +52,13 @@ type Config struct {
|
|||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
```
|
||||
|
||||
### Default Config
|
||||
```go
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
Weak: false,
|
||||
}
|
||||
```
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package etag
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
// to generate efficiently. Weak ETag values of two representations
|
||||
// of the same resources might be semantically equivalent, but not
|
||||
// byte-for-byte identical. This means weak etags prevent caching
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
return cfg
|
||||
}
|
|
@ -8,42 +8,13 @@ import (
|
|||
"github.com/gofiber/fiber/v2/internal/bytebufferpool"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
// to generate efficiently. Weak ETag values of two representations
|
||||
// of the same resources might be semantically equivalent, but not
|
||||
// byte-for-byte identical. This means weak etags prevent caching
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
var normalizedHeaderETag = []byte("Etag")
|
||||
var weakPrefix = []byte("W/")
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var crc32q = crc32.MakeTable(0xD5828281)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -26,16 +25,16 @@ func New(config ...Config) fiber.Handler {
|
|||
cfg := configDefault(config...)
|
||||
|
||||
var (
|
||||
// Limiter settings
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
timestamp = uint64(time.Now().Unix())
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
mux = &sync.RWMutex{}
|
||||
|
||||
// Default store logic (if no Store is provided)
|
||||
entries = make(map[string]entry)
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
go func() {
|
||||
for {
|
||||
|
@ -54,65 +53,39 @@ func New(config ...Config) fiber.Handler {
|
|||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Create new entry
|
||||
entry := entry{}
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
val, err := cfg.Storage.Get(key)
|
||||
if val != nil && len(val) > 0 {
|
||||
if _, err := entry.UnmarshalMsg(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err != nil && err.Error() != errNotExist {
|
||||
fmt.Println("[LIMITER]", err.Error())
|
||||
}
|
||||
} else {
|
||||
entry = entries[key]
|
||||
}
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if entry.exp == 0 {
|
||||
entry.exp = ts + expiration
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
|
||||
} else if ts >= entry.exp {
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
entry.hits = 0
|
||||
entry.exp = ts + expiration
|
||||
e.hits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
entry.hits++
|
||||
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
// Marshal entry to bytes
|
||||
val, err := entry.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass value to Storage
|
||||
if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
entries[key] = entry
|
||||
}
|
||||
e.hits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
expire := entry.exp - ts
|
||||
expire := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - entry.hits
|
||||
remaining := cfg.Max - e.hits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -172,7 +171,6 @@ func Test_Limiter_Headers(t *testing.T) {
|
|||
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
|
||||
}
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
|
||||
fmt.Println(v)
|
||||
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
hits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.hits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -48,7 +48,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// write "hits"
|
||||
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
|
||||
|
@ -74,7 +74,7 @@ func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
|
|||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// string "hits"
|
||||
|
@ -87,7 +87,7 @@ func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -129,7 +129,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z entry) Msgsize() (s int) {
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package limiter
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type entry struct {
|
||||
hits int `msg:"hits"`
|
||||
exp uint64 `msg:"exp"`
|
||||
}
|
|
@ -129,7 +129,7 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Set error handler once
|
||||
once.Do(func() {
|
||||
errHandler = c.App().Config().ErrorHandler
|
||||
// get longested possible path
|
||||
stack := c.App().Stack()
|
||||
for m := range stack {
|
||||
for r := range stack[m] {
|
||||
|
@ -139,7 +139,8 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// override error handler
|
||||
errHandler = c.App().Config().ErrorHandler
|
||||
})
|
||||
|
||||
// Set latency start time
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package session
|
||||
|
||||
import "sync"
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported
|
||||
// msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type db struct {
|
||||
type data struct {
|
||||
d []kv
|
||||
}
|
||||
|
||||
|
@ -14,11 +16,26 @@ type kv struct {
|
|||
v interface{}
|
||||
}
|
||||
|
||||
func (d *db) Reset() {
|
||||
var dataPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(data)
|
||||
},
|
||||
}
|
||||
|
||||
func acquireData() *data {
|
||||
return dataPool.Get().(*data)
|
||||
}
|
||||
|
||||
func releaseData(d *data) {
|
||||
d.Reset()
|
||||
dataPool.Put(d)
|
||||
}
|
||||
|
||||
func (d *data) Reset() {
|
||||
d.d = d.d[:0]
|
||||
}
|
||||
|
||||
func (d *db) Get(key string) interface{} {
|
||||
func (d *data) Get(key string) interface{} {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
return d.d[idx].v
|
||||
|
@ -26,7 +43,7 @@ func (d *db) Get(key string) interface{} {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *db) Set(key string, value interface{}) {
|
||||
func (d *data) Set(key string, value interface{}) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
kv := &d.d[idx]
|
||||
|
@ -36,7 +53,7 @@ func (d *db) Set(key string, value interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *db) Delete(key string) {
|
||||
func (d *data) Delete(key string) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
n := len(d.d) - 1
|
||||
|
@ -45,11 +62,11 @@ func (d *db) Delete(key string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *db) Len() int {
|
||||
func (d *data) Len() int {
|
||||
return len(d.d)
|
||||
}
|
||||
|
||||
func (d *db) swap(i, j int) {
|
||||
func (d *data) swap(i, j int) {
|
||||
iKey, iValue := d.d[i].k, d.d[i].v
|
||||
jKey, jValue := d.d[j].k, d.d[j].v
|
||||
|
||||
|
@ -57,7 +74,7 @@ func (d *db) swap(i, j int) {
|
|||
d.d[j].k, d.d[j].v = iKey, iValue
|
||||
}
|
||||
|
||||
func (d *db) allocPage() *kv {
|
||||
func (d *data) allocPage() *kv {
|
||||
n := len(d.d)
|
||||
if cap(d.d) > n {
|
||||
d.d = d.d[:n+1]
|
||||
|
@ -67,13 +84,13 @@ func (d *db) allocPage() *kv {
|
|||
return &d.d[n]
|
||||
}
|
||||
|
||||
func (d *db) append(key string, value interface{}) {
|
||||
func (d *data) append(key string, value interface{}) {
|
||||
kv := d.allocPage()
|
||||
kv.k = key
|
||||
kv.v = value
|
||||
}
|
||||
|
||||
func (d *db) indexOf(key string) int {
|
||||
func (d *data) indexOf(key string) int {
|
||||
n := len(d.d)
|
||||
for i := 0; i < n; i++ {
|
||||
if d.d[i].k == key {
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
func (z *data) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -84,7 +84,7 @@ func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
|
|||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
func (z *data) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 1
|
||||
// write "d"
|
||||
err = en.Append(0x81, 0xa1, 0x64)
|
||||
|
@ -123,7 +123,7 @@ func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
|
|||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
func (z *data) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 1
|
||||
// string "d"
|
||||
|
@ -146,7 +146,7 @@ func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
|
@ -224,7 +224,7 @@ func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
|||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *db) Msgsize() (s int) {
|
||||
func (z *data) Msgsize() (s int) {
|
||||
s = 1 + 2 + msgp.ArrayHeaderSize
|
||||
for za0001 := range z.d {
|
||||
s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v)
|
|
@ -10,11 +10,11 @@ import (
|
|||
)
|
||||
|
||||
type Session struct {
|
||||
ctx *fiber.Ctx
|
||||
config *Store
|
||||
db *db
|
||||
id string
|
||||
fresh bool
|
||||
id string // session id
|
||||
fresh bool // if new session
|
||||
ctx *fiber.Ctx // fiber context
|
||||
config *Store // store configuration
|
||||
data *data // key value data
|
||||
}
|
||||
|
||||
var sessionPool = sync.Pool{
|
||||
|
@ -25,19 +25,20 @@ var sessionPool = sync.Pool{
|
|||
|
||||
func acquireSession() *Session {
|
||||
s := sessionPool.Get().(*Session)
|
||||
s.db = new(db)
|
||||
if s.data == nil {
|
||||
s.data = new(data)
|
||||
}
|
||||
s.fresh = true
|
||||
return s
|
||||
}
|
||||
|
||||
func releaseSession(s *Session) {
|
||||
s.id = ""
|
||||
s.ctx = nil
|
||||
s.config = nil
|
||||
if s.db != nil {
|
||||
s.db.Reset()
|
||||
if s.data != nil {
|
||||
s.data.Reset()
|
||||
}
|
||||
s.id = ""
|
||||
s.fresh = true
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
|
@ -53,25 +54,42 @@ func (s *Session) ID() string {
|
|||
|
||||
// Get will return the value
|
||||
func (s *Session) Get(key string) interface{} {
|
||||
return s.db.Get(key)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
return s.data.Get(key)
|
||||
}
|
||||
|
||||
// Set will update or create a new key value
|
||||
func (s *Session) Set(key string, val interface{}) {
|
||||
s.db.Set(key, val)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Set(key, val)
|
||||
}
|
||||
|
||||
// Delete will delete the value
|
||||
func (s *Session) Delete(key string) {
|
||||
s.db.Delete(key)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Delete(key)
|
||||
}
|
||||
|
||||
// Destroy will delete the session from Storage and expire session cookie
|
||||
func (s *Session) Destroy() error {
|
||||
// Reset local data
|
||||
s.db.Reset()
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete data from storage
|
||||
// Reset local data
|
||||
s.data.Reset()
|
||||
|
||||
// Use external Storage if exist
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -88,6 +106,7 @@ func (s *Session) Regenerate() error {
|
|||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create new ID
|
||||
s.id = s.config.KeyGenerator()
|
||||
|
||||
|
@ -96,13 +115,18 @@ func (s *Session) Regenerate() error {
|
|||
|
||||
// Save will update the storage and client cookie
|
||||
func (s *Session) Save() error {
|
||||
// Don't save to Storage if no data is available
|
||||
if s.db.Len() <= 0 {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert book to bytes
|
||||
data, err := s.db.MarshalMsg(nil)
|
||||
// Don't save to Storage if no data is available
|
||||
if s.data.Len() <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert data to bytes
|
||||
data, err := s.data.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -115,7 +139,8 @@ func (s *Session) Save() error {
|
|||
// Create cookie with the session ID
|
||||
s.setCookie()
|
||||
|
||||
// release session to pool to be re-used on next request
|
||||
// Release session
|
||||
// TODO: It's not safe to use the Session after called Save()
|
||||
releaseSession(s)
|
||||
|
||||
return nil
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
@ -170,3 +171,34 @@ func Test_Session_Cookie(t *testing.T) {
|
|||
// cookie should not be set if empty data
|
||||
utils.AssertEqual(t, 0, len(ctx.Response().Header.PeekCookie(store.CookieName)))
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
|
||||
func Benchmark_Session(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.CookieName, "12356789")
|
||||
|
||||
b.Run("default", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
store = New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -42,19 +42,21 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
|||
sess.ctx = c
|
||||
sess.config = s
|
||||
sess.id = id
|
||||
sess.fresh = fresh
|
||||
|
||||
// Fetch existing data
|
||||
if !fresh {
|
||||
raw, err := s.Storage.Get(id)
|
||||
// Unmashal if we found data
|
||||
if err == nil {
|
||||
if _, err = sess.db.UnmarshalMsg(raw); err != nil {
|
||||
if _, err = sess.data.UnmarshalMsg(raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sess.fresh = false
|
||||
} else if err.Error() != errNotExist {
|
||||
// Only return error if it's not ErrNotExist
|
||||
} else if raw != nil && err.Error() != "key does not exist" {
|
||||
return nil, err
|
||||
} else {
|
||||
sess.fresh = true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ package utils
|
|||
|
||||
import "testing"
|
||||
|
||||
func Test_Utils_AssertEqual(t *testing.T) {
|
||||
func Test_AssertEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
AssertEqual(nil, []string{}, []string{})
|
||||
AssertEqual(t, []string{}, []string{})
|
||||
|
|
|
@ -56,7 +56,7 @@ func TrimBytes(b []byte, cutset byte) []byte {
|
|||
}
|
||||
|
||||
// EqualFold the equivalent of bytes.EqualFold
|
||||
func EqualsFold(b, s []byte) (equals bool) {
|
||||
func EqualFoldBytes(b, s []byte) (equals bool) {
|
||||
n := len(b)
|
||||
equals = n == len(s)
|
||||
if equals {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_ToLowerBytes(t *testing.T) {
|
||||
func Test_ToLowerBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToLowerBytes([]byte("/MY/NAME/IS/:PARAM/*"))
|
||||
AssertEqual(t, true, bytes.Equal([]byte("/my/name/is/:param/*"), res))
|
||||
|
@ -41,7 +41,7 @@ func Benchmark_ToLowerBytes(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ToUpperBytes(t *testing.T) {
|
||||
func Test_ToUpperBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToUpperBytes([]byte("/my/name/is/:param/*"))
|
||||
AssertEqual(t, true, bytes.Equal([]byte("/MY/NAME/IS/:PARAM/*"), res))
|
||||
|
@ -73,7 +73,7 @@ func Benchmark_ToUpperBytes(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimRightBytes(t *testing.T) {
|
||||
func Test_TrimRightBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimRightBytes([]byte("/test//////"), '/')
|
||||
AssertEqual(t, []byte("/test"), res)
|
||||
|
@ -99,7 +99,7 @@ func Benchmark_TrimRightBytes(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimLeftBytes(t *testing.T) {
|
||||
func Test_TrimLeftBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimLeftBytes([]byte("////test/"), '/')
|
||||
AssertEqual(t, []byte("test/"), res)
|
||||
|
@ -123,7 +123,7 @@ func Benchmark_TrimLeftBytes(b *testing.B) {
|
|||
AssertEqual(b, []byte("foobar"), res)
|
||||
})
|
||||
}
|
||||
func Test_Utils_TrimBytes(t *testing.T) {
|
||||
func Test_TrimBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimBytes([]byte(" test "), ' ')
|
||||
AssertEqual(t, []byte("test"), res)
|
||||
|
@ -151,14 +151,14 @@ func Benchmark_TrimBytes(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Benchmark_EqualFolds(b *testing.B) {
|
||||
func Benchmark_EqualFoldBytes(b *testing.B) {
|
||||
var left = []byte("/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts")
|
||||
var right = []byte("/RePos/goFiber/Fiber/issues/187643/COMMENTS")
|
||||
var res bool
|
||||
|
||||
b.Run("fiber", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = EqualsFold(left, right)
|
||||
res = EqualFoldBytes(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
|
@ -170,18 +170,18 @@ func Benchmark_EqualFolds(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_EqualsFold(t *testing.T) {
|
||||
func Test_EqualFoldBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := EqualsFold([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
|
||||
res := EqualFoldBytes([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
|
||||
res = EqualFoldBytes([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
|
||||
res = EqualFoldBytes([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualsFold([]byte("/dddddd"), []byte("eeeeee"))
|
||||
res = EqualFoldBytes([]byte("/dddddd"), []byte("eeeeee"))
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualsFold([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
|
||||
res = EqualFoldBytes([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
|
||||
res = EqualFoldBytes([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
}
|
||||
|
|
|
@ -10,24 +10,24 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_FunctionName(t *testing.T) {
|
||||
func Test_FunctionName(t *testing.T) {
|
||||
t.Parallel()
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_UUID", FunctionName(Test_Utils_UUID))
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_UUID", FunctionName(Test_UUID))
|
||||
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_FunctionName.func1", FunctionName(func() {}))
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_FunctionName.func1", FunctionName(func() {}))
|
||||
|
||||
var dummyint = 20
|
||||
AssertEqual(t, "int", FunctionName(dummyint))
|
||||
}
|
||||
|
||||
func Test_Utils_UUID(t *testing.T) {
|
||||
func Test_UUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := UUID()
|
||||
AssertEqual(t, 36, len(res))
|
||||
AssertEqual(t, true, res != "00000000-0000-0000-0000-000000000000")
|
||||
}
|
||||
|
||||
func Test_Utils_UUID_Concurrency(t *testing.T) {
|
||||
func Test_UUID_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
iterations := 10000
|
||||
var res string
|
||||
|
|
|
@ -28,13 +28,13 @@ func UnsafeBytes(s string) (bs []byte) {
|
|||
return
|
||||
}
|
||||
|
||||
// SafeString copies a string to make it immutable
|
||||
func SafeString(s string) string {
|
||||
// CopyString copies a string to make it immutable
|
||||
func CopyString(s string) string {
|
||||
return string(UnsafeBytes(s))
|
||||
}
|
||||
|
||||
// SafeBytes copies a slice to make it immutable
|
||||
func SafeBytes(b []byte) []byte {
|
||||
// CopyBytes copies a slice to make it immutable
|
||||
func CopyBytes(b []byte) []byte {
|
||||
tmp := make([]byte, len(b))
|
||||
copy(tmp, b)
|
||||
return tmp
|
||||
|
@ -83,22 +83,3 @@ func ByteSize(bytes uint64) string {
|
|||
result = strings.TrimSuffix(result, ".0")
|
||||
return result + unit
|
||||
}
|
||||
|
||||
// Deprecated fn's
|
||||
|
||||
// #nosec G103
|
||||
// GetString returns a string pointer without allocation
|
||||
func GetString(b []byte) string {
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// #nosec G103
|
||||
// GetBytes returns a byte pointer without allocation
|
||||
func GetBytes(s string) []byte {
|
||||
return UnsafeBytes(s)
|
||||
}
|
||||
|
||||
// ImmutableString copies a string to make it immutable
|
||||
func ImmutableString(s string) string {
|
||||
return SafeString(s)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ package utils
|
|||
|
||||
import "testing"
|
||||
|
||||
func Test_Utils_GetString(t *testing.T) {
|
||||
func Test_GetString(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetString([]byte("Hello, World!"))
|
||||
AssertEqual(t, "Hello, World!", res)
|
||||
|
@ -31,7 +31,7 @@ func Benchmark_GetString(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_GetBytes(t *testing.T) {
|
||||
func Test_GetBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetBytes("Hello, World!")
|
||||
AssertEqual(t, []byte("Hello, World!"), res)
|
||||
|
@ -56,7 +56,7 @@ func Benchmark_GetBytes(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ImmutableString(t *testing.T) {
|
||||
func Test_ImmutableString(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ImmutableString("Hello, World!")
|
||||
AssertEqual(t, "Hello, World!", res)
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package utils
|
||||
|
||||
// #nosec G103
|
||||
// DEPRECATED, Please use UnsafeString instead
|
||||
func GetString(b []byte) string {
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// #nosec G103
|
||||
// DEPRECATED, Please use UnsafeBytes instead
|
||||
func GetBytes(s string) []byte {
|
||||
return UnsafeBytes(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyString instead
|
||||
func ImmutableString(s string) string {
|
||||
return CopyString(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, please use EqualFoldBytes
|
||||
func EqualsFold(b, s []byte) (equals bool) {
|
||||
return EqualFoldBytes(b, s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyString instead
|
||||
func SafeString(s string) string {
|
||||
return CopyString(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyBytes instead
|
||||
func SafeBytes(b []byte) []byte {
|
||||
return CopyBytes(b)
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_GetMIME(t *testing.T) {
|
||||
func Test_GetMIME(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetMIME(".json")
|
||||
AssertEqual(t, "application/json", res)
|
||||
|
@ -53,7 +53,7 @@ func Benchmark_GetMIME(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_StatusMessage(t *testing.T) {
|
||||
func Test_StatusMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := StatusMessage(204)
|
||||
AssertEqual(t, "No Content", res)
|
||||
|
|
|
@ -60,3 +60,17 @@ func TrimRight(s string, cutset byte) string {
|
|||
}
|
||||
return s[:lenStr]
|
||||
}
|
||||
|
||||
// EqualFold the equivalent of strings.EqualFold
|
||||
func EqualFold(b, s string) (equals bool) {
|
||||
n := len(b)
|
||||
equals = n == len(s)
|
||||
if equals {
|
||||
for i := 0; i < n; i++ {
|
||||
if equals = b[i]|0x20 == s[i]|0x20; !equals {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_ToUpper(t *testing.T) {
|
||||
func Test_ToUpper(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToUpper("/my/name/is/:param/*")
|
||||
AssertEqual(t, "/MY/NAME/IS/:PARAM/*", res)
|
||||
|
@ -33,7 +33,7 @@ func Benchmark_ToUpper(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ToLower(t *testing.T) {
|
||||
func Test_ToLower(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToLower("/MY/NAME/IS/:PARAM/*")
|
||||
AssertEqual(t, "/my/name/is/:param/*", res)
|
||||
|
@ -64,7 +64,7 @@ func Benchmark_ToLower(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimRight(t *testing.T) {
|
||||
func Test_TrimRight(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimRight("/test//////", '/')
|
||||
AssertEqual(t, "/test", res)
|
||||
|
@ -89,7 +89,7 @@ func Benchmark_TrimRight(b *testing.B) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimLeft(t *testing.T) {
|
||||
func Test_TrimLeft(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimLeft("////test/", '/')
|
||||
AssertEqual(t, "test/", res)
|
||||
|
@ -113,7 +113,7 @@ func Benchmark_TrimLeft(b *testing.B) {
|
|||
AssertEqual(b, "foobar", res)
|
||||
})
|
||||
}
|
||||
func Test_Utils_Trim(t *testing.T) {
|
||||
func Test_Trim(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := Trim(" test ", ' ')
|
||||
AssertEqual(t, "test", res)
|
||||
|
@ -147,3 +147,39 @@ func Benchmark_Trim(b *testing.B) {
|
|||
AssertEqual(b, "foobar", res)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_EqualFold -benchmem -count=4
|
||||
func Benchmark_EqualFold(b *testing.B) {
|
||||
var left = "/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts"
|
||||
var right = "/RePos/goFiber/Fiber/issues/187643/COMMENTS"
|
||||
var res bool
|
||||
|
||||
b.Run("fiber", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = EqualFold(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
b.Run("default", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = strings.EqualFold(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_EqualFold(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := EqualFold("/MY/NAME/IS/:PARAM/*", "/my/name/is/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/MY1/NAME/IS/:PARAM/*", "/MY1/NAME/IS/:PARAM/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/my2/name/is/:param/*", "/my2/name")
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualFold("/dddddd", "eeeeee")
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualFold("/MY3/NAME/IS/:PARAM/*", "/my3/name/is/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/MY4/NAME/IS/:PARAM/*", "/my4/nAME/IS/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue