mirror of https://github.com/gofiber/fiber.git
🧹 housekeeping
parent
015de85e30
commit
8bd50de610
27
ctx_test.go
27
ctx_test.go
|
@ -934,6 +934,33 @@ func Test_Ctx_MultipartForm(t *testing.T) {
|
|||
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Ctx_MultipartForm -benchmem -count=4
|
||||
func Benchmark_Ctx_MultipartForm(b *testing.B) {
|
||||
app := New()
|
||||
|
||||
app.Post("/", func(c *Ctx) error {
|
||||
_, _ = c.MultipartForm()
|
||||
return nil
|
||||
})
|
||||
|
||||
c := &fasthttp.RequestCtx{}
|
||||
|
||||
body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
|
||||
c.Request.SetBody(body)
|
||||
c.Request.Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
|
||||
c.Request.Header.SetContentLength(len(body))
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(c)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// go test -run Test_Ctx_OriginalURL
|
||||
func Test_Ctx_OriginalURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
31
helpers.go
31
helpers.go
|
@ -117,6 +117,37 @@ func removeNewLines(raw string) string {
|
|||
return raw
|
||||
}
|
||||
|
||||
// removeNewLines will replace `\r` and `\n` with an empty space
|
||||
func removeNewLinesBytes(raw []byte) []byte {
|
||||
var (
|
||||
start = 0
|
||||
// strings.IndexByte is faster than bytes.IndexByte
|
||||
rawStr = utils.UnsafeString(raw) // b2s()
|
||||
)
|
||||
// check if a `\r` is present and save the position.
|
||||
// if no `\r` is found, check if a `\n` is present,
|
||||
if start = strings.IndexByte(rawStr, '\r'); start == -1 {
|
||||
// check if a `\n` is present if no `\r` is found
|
||||
if start = strings.IndexByte(rawStr, '\n'); start == -1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
// loop from start position to replace `\r` or `\n` with empty space
|
||||
for i := start; i < len(raw); i++ {
|
||||
// switch raw[i] {
|
||||
// case '\r', '\n':
|
||||
// raw[i] = ' '
|
||||
// default:
|
||||
// continue
|
||||
// }
|
||||
if raw[i] != '\r' && raw[i] != '\n' {
|
||||
continue
|
||||
}
|
||||
raw[i] = ' '
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// Scan stack if other methods match the request
|
||||
func methodExist(ctx *Ctx) (exist bool) {
|
||||
for i := 0; i < len(intMethod); i++ {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package fiber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -15,22 +16,14 @@ import (
|
|||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Utils_RemoveNewLines -benchmem -count=4
|
||||
func Benchmark_Utils_RemoveNewLines(b *testing.B) {
|
||||
// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4
|
||||
func Benchmark_RemoveNewLines(b *testing.B) {
|
||||
withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n"
|
||||
withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue "
|
||||
expected := utils.SafeString(withoutNL)
|
||||
var res string
|
||||
|
||||
b.Run("withNewlines", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLines(withNL)
|
||||
}
|
||||
utils.AssertEqual(b, expected, res)
|
||||
})
|
||||
b.Run("withoutNewlines", func(b *testing.B) {
|
||||
b.Run("withoutNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
@ -38,9 +31,81 @@ func Benchmark_Utils_RemoveNewLines(b *testing.B) {
|
|||
}
|
||||
utils.AssertEqual(b, expected, res)
|
||||
})
|
||||
|
||||
b.Run("withNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLines(withNL)
|
||||
}
|
||||
utils.AssertEqual(b, expected, res)
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_RemoveNewLines_Bytes(b *testing.B) {
|
||||
withNL := []byte("foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n")
|
||||
withoutNL := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ")
|
||||
expected := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ")
|
||||
var res []byte
|
||||
|
||||
b.Run("withoutNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLinesBytes(withoutNL)
|
||||
}
|
||||
utils.AssertEqual(b, true, bytes.Equal(res, expected))
|
||||
})
|
||||
|
||||
b.Run("withNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLinesBytes(withNL)
|
||||
}
|
||||
utils.AssertEqual(b, true, bytes.Equal(res, expected))
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=RemoveNewLines_Bytes -count=3
|
||||
func Test_RemoveNewLines_Bytes(t *testing.T) {
|
||||
app := New()
|
||||
t.Run("Not Status OK", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
c.Status(201)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
|
||||
t.Run("No Body", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
|
||||
t.Run("Has HeaderIfNoneMatch", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, 304, c.Response().StatusCode())
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
utils.AssertEqual(t, "", string(c.Response().Body()))
|
||||
})
|
||||
|
||||
t.Run("No HeaderIfNoneMatch", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// go test -v -run=Test_Utils_ -count=3
|
||||
func Test_Utils_ETag(t *testing.T) {
|
||||
app := New()
|
||||
|
|
|
@ -8,97 +8,10 @@ import (
|
|||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Users defines the allowed credentials
|
||||
//
|
||||
// Required. Default: map[string]string{}
|
||||
Users map[string]string
|
||||
|
||||
// Realm is a string to define realm attribute of BasicAuth.
|
||||
// the realm identifies the system to authenticate against
|
||||
// and can be used by clients to save credentials
|
||||
//
|
||||
// Optional. Default: "Restricted".
|
||||
Realm string
|
||||
|
||||
// Authorizer defines a function you can pass
|
||||
// to check the credentials however you want.
|
||||
// It will be called with a username and password
|
||||
// and is expected to return true or false to indicate
|
||||
// that the credentials were approved or not.
|
||||
//
|
||||
// Optional. Default: nil.
|
||||
Authorizer func(string, string) bool
|
||||
|
||||
// Unauthorized defines the response body for unauthorized responses.
|
||||
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Unauthorized fiber.Handler
|
||||
|
||||
// ContextUser is the key to store the username in Locals
|
||||
//
|
||||
// Optional. Default: "username"
|
||||
ContextUsername string
|
||||
|
||||
// ContextPass is the key to store the password in Locals
|
||||
//
|
||||
// Optional. Default: "password"
|
||||
ContextPassword string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Users: map[string]string{},
|
||||
Realm: "Restricted",
|
||||
Authorizer: nil,
|
||||
Unauthorized: nil,
|
||||
ContextUsername: "username",
|
||||
ContextPassword: "password",
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config Config) fiber.Handler {
|
||||
cfg := config
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Users == nil {
|
||||
cfg.Users = ConfigDefault.Users
|
||||
}
|
||||
if cfg.Realm == "" {
|
||||
cfg.Realm = ConfigDefault.Realm
|
||||
}
|
||||
if cfg.Authorizer == nil {
|
||||
cfg.Authorizer = func(user, pass string) bool {
|
||||
user, exist := cfg.Users[user]
|
||||
if !exist {
|
||||
return false
|
||||
}
|
||||
return user == pass
|
||||
}
|
||||
}
|
||||
if cfg.Unauthorized == nil {
|
||||
cfg.Unauthorized = func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
if cfg.ContextUsername == "" {
|
||||
cfg.ContextUsername = ConfigDefault.ContextUsername
|
||||
}
|
||||
if cfg.ContextPassword == "" {
|
||||
cfg.ContextPassword = ConfigDefault.ContextPassword
|
||||
}
|
||||
// Set default config
|
||||
cfg := configDefault(config)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Users defines the allowed credentials
|
||||
//
|
||||
// Required. Default: map[string]string{}
|
||||
Users map[string]string
|
||||
|
||||
// Realm is a string to define realm attribute of BasicAuth.
|
||||
// the realm identifies the system to authenticate against
|
||||
// and can be used by clients to save credentials
|
||||
//
|
||||
// Optional. Default: "Restricted".
|
||||
Realm string
|
||||
|
||||
// Authorizer defines a function you can pass
|
||||
// to check the credentials however you want.
|
||||
// It will be called with a username and password
|
||||
// and is expected to return true or false to indicate
|
||||
// that the credentials were approved or not.
|
||||
//
|
||||
// Optional. Default: nil.
|
||||
Authorizer func(string, string) bool
|
||||
|
||||
// Unauthorized defines the response body for unauthorized responses.
|
||||
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Unauthorized fiber.Handler
|
||||
|
||||
// ContextUser is the key to store the username in Locals
|
||||
//
|
||||
// Optional. Default: "username"
|
||||
ContextUsername string
|
||||
|
||||
// ContextPass is the key to store the password in Locals
|
||||
//
|
||||
// Optional. Default: "password"
|
||||
ContextPassword string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Users: map[string]string{},
|
||||
Realm: "Restricted",
|
||||
Authorizer: nil,
|
||||
Unauthorized: nil,
|
||||
ContextUsername: "username",
|
||||
ContextPassword: "password",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Users == nil {
|
||||
cfg.Users = ConfigDefault.Users
|
||||
}
|
||||
if cfg.Realm == "" {
|
||||
cfg.Realm = ConfigDefault.Realm
|
||||
}
|
||||
if cfg.Authorizer == nil {
|
||||
cfg.Authorizer = func(user, pass string) bool {
|
||||
user, exist := cfg.Users[user]
|
||||
if !exist {
|
||||
return false
|
||||
}
|
||||
return user == pass
|
||||
}
|
||||
}
|
||||
if cfg.Unauthorized == nil {
|
||||
cfg.Unauthorized = func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
if cfg.ContextUsername == "" {
|
||||
cfg.ContextUsername = ConfigDefault.ContextUsername
|
||||
}
|
||||
if cfg.ContextPassword == "" {
|
||||
cfg.ContextPassword = ConfigDefault.ContextPassword
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -12,74 +12,10 @@ import (
|
|||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Expiration is the time that an cached response will live
|
||||
//
|
||||
// Optional. Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// CacheControl enables client side caching if set to true
|
||||
//
|
||||
// Optional. Default: false
|
||||
CacheControl bool
|
||||
|
||||
// Key allows you to generate custom keys, by default c.Path() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
defaultStore bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
defaultStore: true,
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.Store == nil {
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
}
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var (
|
||||
// Cache settings
|
||||
|
@ -152,7 +88,7 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
} else {
|
||||
// Load data from store
|
||||
storeEntry, err := cfg.Store.Get(key)
|
||||
storeEntry, err := cfg.Storage.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -165,7 +101,7 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
if entryBody, err = cfg.Store.Get(key + "_body"); err != nil {
|
||||
if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -183,10 +119,10 @@ func New(config ...Config) fiber.Handler {
|
|||
if cfg.defaultStore {
|
||||
delete(entries, key)
|
||||
} else { // Use custom storage
|
||||
if err := cfg.Store.Delete(key); err != nil {
|
||||
if err := cfg.Storage.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cfg.Store.Delete(key + "_body"); err != nil {
|
||||
if err := cfg.Storage.Delete(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -234,12 +170,12 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
|
||||
if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Store.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
|
||||
if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Expiration is the time that an cached response will live
|
||||
//
|
||||
// Optional. Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// CacheControl enables client side caching if set to true
|
||||
//
|
||||
// Optional. Default: false
|
||||
CacheControl bool
|
||||
|
||||
// Key allows you to generate custom keys, by default c.Path() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// Deprecated, use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
defaultStore bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
defaultStore: true,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.Storage == nil && cfg.Store == nil {
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("cache: `Store` is deprecated, use `Storage` instead")
|
||||
cfg.Storage = cfg.Store
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -5,54 +5,10 @@ import (
|
|||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Level determines the compression algorithm
|
||||
//
|
||||
// Optional. Default: LevelDefault
|
||||
// LevelDisabled: -1
|
||||
// LevelDefault: 0
|
||||
// LevelBestSpeed: 1
|
||||
// LevelBestCompression: 2
|
||||
Level Level
|
||||
}
|
||||
|
||||
// Level is numeric representation of compression level
|
||||
type Level int
|
||||
|
||||
// Represents compression level that will be used in the middleware
|
||||
const (
|
||||
LevelDisabled Level = -1
|
||||
LevelDefault Level = 0
|
||||
LevelBestSpeed Level = 1
|
||||
LevelBestCompression Level = 2
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
|
||||
cfg.Level = ConfigDefault.Level
|
||||
}
|
||||
}
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Setup request handlers
|
||||
var (
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Level determines the compression algorithm
|
||||
//
|
||||
// Optional. Default: LevelDefault
|
||||
// LevelDisabled: -1
|
||||
// LevelDefault: 0
|
||||
// LevelBestSpeed: 1
|
||||
// LevelBestCompression: 2
|
||||
Level Level
|
||||
}
|
||||
|
||||
// Level is numeric representation of compression level
|
||||
type Level int
|
||||
|
||||
// Represents compression level that will be used in the middleware
|
||||
const (
|
||||
LevelDisabled Level = -1
|
||||
LevelDefault Level = 0
|
||||
LevelBestSpeed Level = 1
|
||||
LevelBestCompression Level = 2
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
|
||||
cfg.Level = ConfigDefault.Level
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -40,9 +40,14 @@ type Config struct {
|
|||
|
||||
// Expiration is the duration before csrf token will expire
|
||||
//
|
||||
// Optional. Default: 24 * time.Hour
|
||||
// Optional. Default: 1 * time.Hour
|
||||
Expiration time.Duration
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Context key to store generated CSRF token into context.
|
||||
//
|
||||
// Optional. Default value "csrf".
|
||||
|
@ -58,8 +63,8 @@ var ConfigDefault = Config{
|
|||
Name: "_csrf",
|
||||
SameSite: "Strict",
|
||||
},
|
||||
Expiration: 24 * time.Hour,
|
||||
CookieExpires: 24 * time.Hour, // deprecated
|
||||
Expiration: 1 * time.Hour,
|
||||
CookieExpires: 1 * time.Hour, // deprecated
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Duration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// Key allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// DEPRECATED: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// DEPRECATED, use Storage instead
|
||||
Store fiber.Storage
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
if cfg.Expiration != ConfigDefault.Expiration {
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("[LIMITER] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -10,69 +10,11 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
//go:generate msgp -unexported
|
||||
//msgp:ignore Config
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Duration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// DEPRECATED: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// Key allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
defaultStore bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
defaultStore: true,
|
||||
}
|
||||
|
||||
// X-RateLimit-* headers
|
||||
const (
|
||||
// Storage ErrNotExist
|
||||
errNotExist = "key does not exist"
|
||||
|
||||
// X-RateLimit-* headers
|
||||
xRateLimitLimit = "X-RateLimit-Limit"
|
||||
xRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
xRateLimitReset = "X-RateLimit-Reset"
|
||||
|
@ -81,38 +23,7 @@ const (
|
|||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
fmt.Println("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
if cfg.Expiration != ConfigDefault.Expiration {
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.Store == nil {
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
}
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var (
|
||||
// Limiter settings
|
||||
|
@ -150,21 +61,19 @@ func New(config ...Config) fiber.Handler {
|
|||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
entry = entries[key]
|
||||
} else { // Use custom storage
|
||||
storeEntry, err := cfg.Store.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Only decode if we found an entry
|
||||
if storeEntry != nil {
|
||||
// Decode bytes using msgp
|
||||
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
val, err := cfg.Storage.Get(key)
|
||||
if val != nil && len(val) > 0 {
|
||||
if _, err := entry.UnmarshalMsg(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err != nil && err.Error() != errNotExist {
|
||||
fmt.Println("[LIMITER]", err.Error())
|
||||
}
|
||||
} else {
|
||||
entry = entries[key]
|
||||
}
|
||||
|
||||
// Get timestamp
|
||||
|
@ -183,19 +92,20 @@ func New(config ...Config) fiber.Handler {
|
|||
// Increment hits
|
||||
entry.hits++
|
||||
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
entries[key] = entry
|
||||
} else { // Use custom storage
|
||||
data, err := entry.MarshalMsg(nil)
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
// Marshal entry to bytes
|
||||
val, err := entry.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
|
||||
// Pass value to Storage
|
||||
if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
entries[key] = entry
|
||||
}
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
|
@ -223,10 +133,3 @@ func New(config ...Config) fiber.Handler {
|
|||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// replacer for strconv.FormatUint
|
||||
// func appendInt(buf *bytebufferpool.ByteBuffer, v int) (int, error) {
|
||||
// old := len(buf.B)
|
||||
// buf.B = fasthttp.AppendUint(buf.B, v)
|
||||
// return len(buf.B) - old, nil
|
||||
// }
|
||||
|
|
|
@ -9,9 +9,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) {
|
|||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
|
@ -108,32 +108,6 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||
func Benchmark_Limiter(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
||||
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
@ -141,7 +115,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
|||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
|
@ -203,39 +177,28 @@ func Test_Limiter_Headers(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// testStore is used for testing custom stores
|
||||
type testStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex *sync.Mutex
|
||||
}
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||
func Benchmark_Limiter(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
func (s testStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.Unlock()
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else {
|
||||
return val, nil
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -186,6 +186,7 @@ func New(config ...Config) fiber.Handler {
|
|||
var (
|
||||
start, stop time.Time
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
errHandler fiber.ErrorHandler
|
||||
)
|
||||
|
||||
|
@ -362,6 +363,7 @@ func New(config ...Config) fiber.Handler {
|
|||
if err != nil {
|
||||
_, _ = buf.WriteString(err.Error())
|
||||
}
|
||||
mu.Lock()
|
||||
// Write buffer to output
|
||||
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
|
||||
// Write error to output
|
||||
|
@ -370,6 +372,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// TODO: What should we do here?
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
// Put buffer back to pool
|
||||
bytebufferpool.Put(buf)
|
||||
|
||||
|
|
Loading…
Reference in New Issue