🧹 housekeeping

pull/1009/head
Fenny 2020-11-11 13:54:27 +01:00
parent 015de85e30
commit 8bd50de610
14 changed files with 556 additions and 403 deletions

View File

@ -934,6 +934,33 @@ func Test_Ctx_MultipartForm(t *testing.T) {
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")
}
// go test -v -run=^$ -bench=Benchmark_Ctx_MultipartForm -benchmem -count=4
func Benchmark_Ctx_MultipartForm(b *testing.B) {
app := New()
app.Post("/", func(c *Ctx) error {
_, _ = c.MultipartForm()
return nil
})
c := &fasthttp.RequestCtx{}
body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
c.Request.SetBody(body)
c.Request.Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
c.Request.Header.SetContentLength(len(body))
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(c)
}
}
// go test -run Test_Ctx_OriginalURL
func Test_Ctx_OriginalURL(t *testing.T) {
t.Parallel()

View File

@ -117,6 +117,37 @@ func removeNewLines(raw string) string {
return raw
}
// removeNewLines will replace `\r` and `\n` with an empty space
func removeNewLinesBytes(raw []byte) []byte {
var (
start = 0
// strings.IndexByte is faster than bytes.IndexByte
rawStr = utils.UnsafeString(raw) // b2s()
)
// check if a `\r` is present and save the position.
// if no `\r` is found, check if a `\n` is present,
if start = strings.IndexByte(rawStr, '\r'); start == -1 {
// check if a `\n` is present if no `\r` is found
if start = strings.IndexByte(rawStr, '\n'); start == -1 {
return raw
}
}
// loop from start position to replace `\r` or `\n` with empty space
for i := start; i < len(raw); i++ {
// switch raw[i] {
// case '\r', '\n':
// raw[i] = ' '
// default:
// continue
// }
if raw[i] != '\r' && raw[i] != '\n' {
continue
}
raw[i] = ' '
}
return raw
}
// Scan stack if other methods match the request
func methodExist(ctx *Ctx) (exist bool) {
for i := 0; i < len(intMethod); i++ {

View File

@ -5,6 +5,7 @@
package fiber
import (
"bytes"
"crypto/tls"
"fmt"
"net"
@ -15,22 +16,14 @@ import (
"github.com/valyala/fasthttp"
)
// go test -v -run=^$ -bench=Benchmark_Utils_RemoveNewLines -benchmem -count=4
func Benchmark_Utils_RemoveNewLines(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4
func Benchmark_RemoveNewLines(b *testing.B) {
withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n"
withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue "
expected := utils.SafeString(withoutNL)
var res string
b.Run("withNewlines", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLines(withNL)
}
utils.AssertEqual(b, expected, res)
})
b.Run("withoutNewlines", func(b *testing.B) {
b.Run("withoutNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
@ -38,9 +31,81 @@ func Benchmark_Utils_RemoveNewLines(b *testing.B) {
}
utils.AssertEqual(b, expected, res)
})
b.Run("withNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLines(withNL)
}
utils.AssertEqual(b, expected, res)
})
}
func Benchmark_RemoveNewLines_Bytes(b *testing.B) {
withNL := []byte("foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n")
withoutNL := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ")
expected := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ")
var res []byte
b.Run("withoutNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLinesBytes(withoutNL)
}
utils.AssertEqual(b, true, bytes.Equal(res, expected))
})
b.Run("withNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLinesBytes(withNL)
}
utils.AssertEqual(b, true, bytes.Equal(res, expected))
})
}
// go test -v -run=RemoveNewLines_Bytes -count=3
func Test_RemoveNewLines_Bytes(t *testing.T) {
app := New()
t.Run("Not Status OK", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
c.Status(201)
setETag(c, false)
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
})
t.Run("No Body", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
setETag(c, false)
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
})
t.Run("Has HeaderIfNoneMatch", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`)
setETag(c, false)
utils.AssertEqual(t, 304, c.Response().StatusCode())
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
utils.AssertEqual(t, "", string(c.Response().Body()))
})
t.Run("No HeaderIfNoneMatch", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
setETag(c, false)
utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag)))
})
}
// go test -v -run=Test_Utils_ -count=3
func Test_Utils_ETag(t *testing.T) {
app := New()

View File

@ -8,97 +8,10 @@ import (
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Users defines the allowed credentials
//
// Required. Default: map[string]string{}
Users map[string]string
// Realm is a string to define realm attribute of BasicAuth.
// the realm identifies the system to authenticate against
// and can be used by clients to save credentials
//
// Optional. Default: "Restricted".
Realm string
// Authorizer defines a function you can pass
// to check the credentials however you want.
// It will be called with a username and password
// and is expected to return true or false to indicate
// that the credentials were approved or not.
//
// Optional. Default: nil.
Authorizer func(string, string) bool
// Unauthorized defines the response body for unauthorized responses.
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
//
// Optional. Default: nil
Unauthorized fiber.Handler
// ContextUser is the key to store the username in Locals
//
// Optional. Default: "username"
ContextUsername string
// ContextPass is the key to store the password in Locals
//
// Optional. Default: "password"
ContextPassword string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
ContextUsername: "username",
ContextPassword: "password",
}
// New creates a new middleware handler
func New(config Config) fiber.Handler {
cfg := config
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Users == nil {
cfg.Users = ConfigDefault.Users
}
if cfg.Realm == "" {
cfg.Realm = ConfigDefault.Realm
}
if cfg.Authorizer == nil {
cfg.Authorizer = func(user, pass string) bool {
user, exist := cfg.Users[user]
if !exist {
return false
}
return user == pass
}
}
if cfg.Unauthorized == nil {
cfg.Unauthorized = func(c *fiber.Ctx) error {
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
return c.SendStatus(fiber.StatusUnauthorized)
}
}
if cfg.ContextUsername == "" {
cfg.ContextUsername = ConfigDefault.ContextUsername
}
if cfg.ContextPassword == "" {
cfg.ContextPassword = ConfigDefault.ContextPassword
}
// Set default config
cfg := configDefault(config)
// Return new handler
return func(c *fiber.Ctx) error {

View File

@ -0,0 +1,105 @@
package basicauth
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Users defines the allowed credentials
//
// Required. Default: map[string]string{}
Users map[string]string
// Realm is a string to define realm attribute of BasicAuth.
// the realm identifies the system to authenticate against
// and can be used by clients to save credentials
//
// Optional. Default: "Restricted".
Realm string
// Authorizer defines a function you can pass
// to check the credentials however you want.
// It will be called with a username and password
// and is expected to return true or false to indicate
// that the credentials were approved or not.
//
// Optional. Default: nil.
Authorizer func(string, string) bool
// Unauthorized defines the response body for unauthorized responses.
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
//
// Optional. Default: nil
Unauthorized fiber.Handler
// ContextUser is the key to store the username in Locals
//
// Optional. Default: "username"
ContextUsername string
// ContextPass is the key to store the password in Locals
//
// Optional. Default: "password"
ContextPassword string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
ContextUsername: "username",
ContextPassword: "password",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Users == nil {
cfg.Users = ConfigDefault.Users
}
if cfg.Realm == "" {
cfg.Realm = ConfigDefault.Realm
}
if cfg.Authorizer == nil {
cfg.Authorizer = func(user, pass string) bool {
user, exist := cfg.Users[user]
if !exist {
return false
}
return user == pass
}
}
if cfg.Unauthorized == nil {
cfg.Unauthorized = func(c *fiber.Ctx) error {
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
return c.SendStatus(fiber.StatusUnauthorized)
}
}
if cfg.ContextUsername == "" {
cfg.ContextUsername = ConfigDefault.ContextUsername
}
if cfg.ContextPassword == "" {
cfg.ContextPassword = ConfigDefault.ContextPassword
}
return cfg
}

View File

@ -12,74 +12,10 @@ import (
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Expiration is the time that an cached response will live
//
// Optional. Default: 1 * time.Minute
Expiration time.Duration
// CacheControl enables client side caching if set to true
//
// Optional. Default: false
CacheControl bool
// Key allows you to generate custom keys, by default c.Path() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.Path()
// }
Key func(*fiber.Ctx) string
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
Key: func(c *fiber.Ctx) string {
return c.Path()
},
defaultStore: true,
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.Key == nil {
cfg.Key = ConfigDefault.Key
}
if cfg.Store == nil {
cfg.defaultStore = true
}
}
cfg := configDefault(config...)
var (
// Cache settings
@ -152,7 +88,7 @@ func New(config ...Config) fiber.Handler {
} else {
// Load data from store
storeEntry, err := cfg.Store.Get(key)
storeEntry, err := cfg.Storage.Get(key)
if err != nil {
return err
}
@ -165,7 +101,7 @@ func New(config ...Config) fiber.Handler {
}
}
if entryBody, err = cfg.Store.Get(key + "_body"); err != nil {
if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil {
return err
}
}
@ -183,10 +119,10 @@ func New(config ...Config) fiber.Handler {
if cfg.defaultStore {
delete(entries, key)
} else { // Use custom storage
if err := cfg.Store.Delete(key); err != nil {
if err := cfg.Storage.Delete(key); err != nil {
return err
}
if err := cfg.Store.Delete(key + "_body"); err != nil {
if err := cfg.Storage.Delete(key + "_body"); err != nil {
return err
}
}
@ -234,12 +170,12 @@ func New(config ...Config) fiber.Handler {
}
// Pass bytes to Storage
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil {
return err
}
// Pass bytes to Storage
if err = cfg.Store.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
return err
}
}

87
middleware/cache/config.go vendored Normal file
View File

@ -0,0 +1,87 @@
package cache
import (
"fmt"
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Expiration is the time that an cached response will live
//
// Optional. Default: 1 * time.Minute
Expiration time.Duration
// CacheControl enables client side caching if set to true
//
// Optional. Default: false
CacheControl bool
// Key allows you to generate custom keys, by default c.Path() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.Path()
// }
Key func(*fiber.Ctx) string
// Deprecated, use Storage instead
Store fiber.Storage
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
Key: func(c *fiber.Ctx) string {
return c.Path()
},
defaultStore: true,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.Key == nil {
cfg.Key = ConfigDefault.Key
}
if cfg.Storage == nil && cfg.Store == nil {
cfg.defaultStore = true
}
if cfg.Store != nil {
fmt.Println("cache: `Store` is deprecated, use `Storage` instead")
cfg.Storage = cfg.Store
cfg.defaultStore = true
}
return cfg
}

View File

@ -5,54 +5,10 @@ import (
"github.com/valyala/fasthttp"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Level determines the compression algorithm
//
// Optional. Default: LevelDefault
// LevelDisabled: -1
// LevelDefault: 0
// LevelBestSpeed: 1
// LevelBestCompression: 2
Level Level
}
// Level is numeric representation of compression level
type Level int
// Represents compression level that will be used in the middleware
const (
LevelDisabled Level = -1
LevelDefault Level = 0
LevelBestSpeed Level = 1
LevelBestCompression Level = 2
)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
cfg.Level = ConfigDefault.Level
}
}
cfg := configDefault(config...)
// Setup request handlers
var (

View File

@ -0,0 +1,56 @@
package compress
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Level determines the compression algorithm
//
// Optional. Default: LevelDefault
// LevelDisabled: -1
// LevelDefault: 0
// LevelBestSpeed: 1
// LevelBestCompression: 2
Level Level
}
// Level is numeric representation of compression level
type Level int
// Represents compression level that will be used in the middleware
const (
LevelDisabled Level = -1
LevelDefault Level = 0
LevelBestSpeed Level = 1
LevelBestCompression Level = 2
)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
cfg.Level = ConfigDefault.Level
}
return cfg
}

View File

@ -40,9 +40,14 @@ type Config struct {
// Expiration is the duration before csrf token will expire
//
// Optional. Default: 24 * time.Hour
// Optional. Default: 1 * time.Hour
Expiration time.Duration
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// Context key to store generated CSRF token into context.
//
// Optional. Default value "csrf".
@ -58,8 +63,8 @@ var ConfigDefault = Config{
Name: "_csrf",
SameSite: "Strict",
},
Expiration: 24 * time.Hour,
CookieExpires: 24 * time.Hour, // deprecated
Expiration: 1 * time.Hour,
CookieExpires: 1 * time.Hour, // deprecated
}
type storage struct {

View File

@ -0,0 +1,103 @@
package limiter
import (
"fmt"
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Max number of recent connections during `Duration` seconds before sending a 429 response
//
// Default: 5
Max int
// Key allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.IP()
// }
Key func(*fiber.Ctx) string
// Expiration is the time on how long to keep records of requests in memory
//
// Default: 1 * time.Minute
Expiration time.Duration
// LimitReached is called when a request hits the limit
//
// Default: func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// DEPRECATED: Use Expiration instead
Duration time.Duration
// DEPRECATED, use Storage instead
Store fiber.Storage
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Max: 5,
Expiration: 1 * time.Minute,
Key: func(c *fiber.Ctx) string {
return c.IP()
},
LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Max <= 0 {
cfg.Max = ConfigDefault.Max
}
if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if int(cfg.Duration.Seconds()) > 0 {
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
if cfg.Expiration != ConfigDefault.Expiration {
cfg.Expiration = cfg.Duration
}
}
if cfg.Key == nil {
cfg.Key = ConfigDefault.Key
}
if cfg.LimitReached == nil {
cfg.LimitReached = ConfigDefault.LimitReached
}
if cfg.Store != nil {
fmt.Println("[LIMITER] Store is deprecated, please use Storage")
cfg.Storage = cfg.Store
}
return cfg
}

View File

@ -10,69 +10,11 @@ import (
"github.com/gofiber/fiber/v2"
)
//go:generate msgp -unexported
//msgp:ignore Config
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Max number of recent connections during `Duration` seconds before sending a 429 response
//
// Default: 5
Max int
// DEPRECATED: Use Expiration instead
Duration time.Duration
// Expiration is the time on how long to keep records of requests in memory
//
// Default: 1 * time.Minute
Expiration time.Duration
// Key allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.IP()
// }
Key func(*fiber.Ctx) string
// LimitReached is called when a request hits the limit
//
// Default: func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Max: 5,
Expiration: 1 * time.Minute,
Key: func(c *fiber.Ctx) string {
return c.IP()
},
LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
defaultStore: true,
}
// X-RateLimit-* headers
const (
// Storage ErrNotExist
errNotExist = "key does not exist"
// X-RateLimit-* headers
xRateLimitLimit = "X-RateLimit-Limit"
xRateLimitRemaining = "X-RateLimit-Remaining"
xRateLimitReset = "X-RateLimit-Reset"
@ -81,38 +23,7 @@ const (
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Max <= 0 {
cfg.Max = ConfigDefault.Max
}
if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if int(cfg.Duration.Seconds()) > 0 {
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
if cfg.Expiration != ConfigDefault.Expiration {
cfg.Expiration = cfg.Duration
}
}
if cfg.Key == nil {
cfg.Key = ConfigDefault.Key
}
if cfg.LimitReached == nil {
cfg.LimitReached = ConfigDefault.LimitReached
}
if cfg.Store == nil {
cfg.defaultStore = true
}
}
cfg := configDefault(config...)
var (
// Limiter settings
@ -150,21 +61,19 @@ func New(config ...Config) fiber.Handler {
mux.Lock()
defer mux.Unlock()
// Use default memory storage
if cfg.defaultStore {
entry = entries[key]
} else { // Use custom storage
storeEntry, err := cfg.Store.Get(key)
if err != nil {
return err
}
// Only decode if we found an entry
if storeEntry != nil {
// Decode bytes using msgp
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
// Use Storage if provided
if cfg.Storage != nil {
val, err := cfg.Storage.Get(key)
if val != nil && len(val) > 0 {
if _, err := entry.UnmarshalMsg(val); err != nil {
return err
}
}
if err != nil && err.Error() != errNotExist {
fmt.Println("[LIMITER]", err.Error())
}
} else {
entry = entries[key]
}
// Get timestamp
@ -183,19 +92,20 @@ func New(config ...Config) fiber.Handler {
// Increment hits
entry.hits++
// Use default memory storage
if cfg.defaultStore {
entries[key] = entry
} else { // Use custom storage
data, err := entry.MarshalMsg(nil)
// Use Storage if provided
if cfg.Storage != nil {
// Marshal entry to bytes
val, err := entry.MarshalMsg(nil)
if err != nil {
return err
}
// Pass bytes to Storage
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
// Pass value to Storage
if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
return err
}
} else {
entries[key] = entry
}
// Calculate when it resets in seconds
@ -223,10 +133,3 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
}
// replacer for strconv.FormatUint
// func appendInt(buf *bytebufferpool.ByteBuffer, v int) (int, error) {
// old := len(buf.B)
// buf.B = fasthttp.AppendUint(buf.B, v)
// return len(buf.B) - old, nil
// }

View File

@ -9,9 +9,9 @@ import (
"testing"
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -24,7 +24,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
@ -108,32 +108,6 @@ func Test_Limiter_Concurrency(t *testing.T) {
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) {
app := fiber.New()
@ -141,7 +115,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
@ -203,39 +177,28 @@ func Test_Limiter_Headers(t *testing.T) {
}
}
// testStore is used for testing custom stores
type testStore struct {
stmap map[string][]byte
mutex *sync.Mutex
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()
func (s testStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return nil, nil
} else {
return val, nil
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod("GET")
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s testStore) Reset() error {
return nil
}
func (s testStore) Delete(id string) error {
return nil
}
func (s testStore) Close() error {
return nil
}

View File

@ -186,6 +186,7 @@ func New(config ...Config) fiber.Handler {
var (
start, stop time.Time
once sync.Once
mu sync.Mutex
errHandler fiber.ErrorHandler
)
@ -362,6 +363,7 @@ func New(config ...Config) fiber.Handler {
if err != nil {
_, _ = buf.WriteString(err.Error())
}
mu.Lock()
// Write buffer to output
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
// Write error to output
@ -370,6 +372,7 @@ func New(config ...Config) fiber.Handler {
// TODO: What should we do here?
}
}
mu.Unlock()
// Put buffer back to pool
bytebufferpool.Put(buf)