From 8bd50de610daea2c43eb8eb1aa8d24c47e255add Mon Sep 17 00:00:00 2001 From: Fenny <25108519+Fenny@users.noreply.github.com> Date: Wed, 11 Nov 2020 13:54:27 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20housekeeping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ctx_test.go | 27 ++++++ helpers.go | 31 +++++++ helpers_test.go | 89 +++++++++++++++--- middleware/basicauth/basicauth.go | 91 +----------------- middleware/basicauth/config.go | 105 +++++++++++++++++++++ middleware/cache/cache.go | 78 ++-------------- middleware/cache/config.go | 87 ++++++++++++++++++ middleware/compress/compress.go | 46 +--------- middleware/compress/config.go | 56 +++++++++++ middleware/csrf/csrf.go | 11 ++- middleware/limiter/config.go | 103 +++++++++++++++++++++ middleware/limiter/limiter.go | 143 +++++------------------------ middleware/limiter/limiter_test.go | 89 ++++++------------ middleware/logger/logger.go | 3 + 14 files changed, 556 insertions(+), 403 deletions(-) create mode 100644 middleware/basicauth/config.go create mode 100644 middleware/cache/config.go create mode 100644 middleware/compress/config.go create mode 100644 middleware/limiter/config.go diff --git a/ctx_test.go b/ctx_test.go index c0f1c497..cd22f51b 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -934,6 +934,33 @@ func Test_Ctx_MultipartForm(t *testing.T) { utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") } +// go test -v -run=^$ -bench=Benchmark_Ctx_MultipartForm -benchmem -count=4 +func Benchmark_Ctx_MultipartForm(b *testing.B) { + app := New() + + app.Post("/", func(c *Ctx) error { + _, _ = c.MultipartForm() + return nil + }) + + c := &fasthttp.RequestCtx{} + + body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--") + c.Request.SetBody(body) + c.Request.Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) + c.Request.Header.SetContentLength(len(body)) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(c) + } + +} + // go test -run Test_Ctx_OriginalURL func Test_Ctx_OriginalURL(t *testing.T) { t.Parallel() diff --git a/helpers.go b/helpers.go index 8c53cbf4..40d4f590 100644 --- a/helpers.go +++ b/helpers.go @@ -117,6 +117,37 @@ func removeNewLines(raw string) string { return raw } +// removeNewLines will replace `\r` and `\n` with an empty space +func removeNewLinesBytes(raw []byte) []byte { + var ( + start = 0 + // strings.IndexByte is faster than bytes.IndexByte + rawStr = utils.UnsafeString(raw) // b2s() + ) + // check if a `\r` is present and save the position. + // if no `\r` is found, check if a `\n` is present, + if start = strings.IndexByte(rawStr, '\r'); start == -1 { + // check if a `\n` is present if no `\r` is found + if start = strings.IndexByte(rawStr, '\n'); start == -1 { + return raw + } + } + // loop from start position to replace `\r` or `\n` with empty space + for i := start; i < len(raw); i++ { + // switch raw[i] { + // case '\r', '\n': + // raw[i] = ' ' + // default: + // continue + // } + if raw[i] != '\r' && raw[i] != '\n' { + continue + } + raw[i] = ' ' + } + return raw +} + // Scan stack if other methods match the request func methodExist(ctx *Ctx) (exist bool) { for i := 0; i < len(intMethod); i++ { diff --git a/helpers_test.go b/helpers_test.go index 5b198694..c73f6040 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -5,6 +5,7 @@ package fiber import ( + "bytes" "crypto/tls" "fmt" "net" @@ -15,22 +16,14 @@ import ( "github.com/valyala/fasthttp" ) -// go test -v -run=^$ -bench=Benchmark_Utils_RemoveNewLines -benchmem -count=4 -func Benchmark_Utils_RemoveNewLines(b *testing.B) { +// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4 +func Benchmark_RemoveNewLines(b *testing.B) { withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n" withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue " expected := utils.SafeString(withoutNL) var res string - b.Run("withNewlines", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = removeNewLines(withNL) - } - utils.AssertEqual(b, expected, res) - }) - b.Run("withoutNewlines", func(b *testing.B) { + b.Run("withoutNL", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { @@ -38,9 +31,81 @@ func Benchmark_Utils_RemoveNewLines(b *testing.B) { } utils.AssertEqual(b, expected, res) }) - + b.Run("withNL", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = removeNewLines(withNL) + } + utils.AssertEqual(b, expected, res) + }) } +func Benchmark_RemoveNewLines_Bytes(b *testing.B) { + withNL := []byte("foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n") + withoutNL := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ") + expected := []byte("foo Set-Cookie:%20SESSIONID=MaliciousValue ") + var res []byte + + b.Run("withoutNL", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = removeNewLinesBytes(withoutNL) + } + utils.AssertEqual(b, true, bytes.Equal(res, expected)) + }) + + b.Run("withNL", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = removeNewLinesBytes(withNL) + } + utils.AssertEqual(b, true, bytes.Equal(res, expected)) + }) +} + +// go test -v -run=RemoveNewLines_Bytes -count=3 +func Test_RemoveNewLines_Bytes(t *testing.T) { + app := New() + t.Run("Not Status OK", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + c.Status(201) + setETag(c, false) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + }) + + t.Run("No Body", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + setETag(c, false) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + }) + + t.Run("Has HeaderIfNoneMatch", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`) + setETag(c, false) + utils.AssertEqual(t, 304, c.Response().StatusCode()) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + utils.AssertEqual(t, "", string(c.Response().Body())) + }) + + t.Run("No HeaderIfNoneMatch", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + setETag(c, false) + utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag))) + }) +} + + // go test -v -run=Test_Utils_ -count=3 func Test_Utils_ETag(t *testing.T) { app := New() diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index c0116544..3017be09 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -8,97 +8,10 @@ import ( "github.com/gofiber/fiber/v2/utils" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Users defines the allowed credentials - // - // Required. Default: map[string]string{} - Users map[string]string - - // Realm is a string to define realm attribute of BasicAuth. - // the realm identifies the system to authenticate against - // and can be used by clients to save credentials - // - // Optional. Default: "Restricted". - Realm string - - // Authorizer defines a function you can pass - // to check the credentials however you want. - // It will be called with a username and password - // and is expected to return true or false to indicate - // that the credentials were approved or not. - // - // Optional. Default: nil. - Authorizer func(string, string) bool - - // Unauthorized defines the response body for unauthorized responses. - // By default it will return with a 401 Unauthorized and the correct WWW-Auth header - // - // Optional. Default: nil - Unauthorized fiber.Handler - - // ContextUser is the key to store the username in Locals - // - // Optional. Default: "username" - ContextUsername string - - // ContextPass is the key to store the password in Locals - // - // Optional. Default: "password" - ContextPassword string -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Users: map[string]string{}, - Realm: "Restricted", - Authorizer: nil, - Unauthorized: nil, - ContextUsername: "username", - ContextPassword: "password", -} - // New creates a new middleware handler func New(config Config) fiber.Handler { - cfg := config - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if cfg.Users == nil { - cfg.Users = ConfigDefault.Users - } - if cfg.Realm == "" { - cfg.Realm = ConfigDefault.Realm - } - if cfg.Authorizer == nil { - cfg.Authorizer = func(user, pass string) bool { - user, exist := cfg.Users[user] - if !exist { - return false - } - return user == pass - } - } - if cfg.Unauthorized == nil { - cfg.Unauthorized = func(c *fiber.Ctx) error { - c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm) - return c.SendStatus(fiber.StatusUnauthorized) - } - } - if cfg.ContextUsername == "" { - cfg.ContextUsername = ConfigDefault.ContextUsername - } - if cfg.ContextPassword == "" { - cfg.ContextPassword = ConfigDefault.ContextPassword - } + // Set default config + cfg := configDefault(config) // Return new handler return func(c *fiber.Ctx) error { diff --git a/middleware/basicauth/config.go b/middleware/basicauth/config.go new file mode 100644 index 00000000..1a4d1fce --- /dev/null +++ b/middleware/basicauth/config.go @@ -0,0 +1,105 @@ +package basicauth + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Users defines the allowed credentials + // + // Required. Default: map[string]string{} + Users map[string]string + + // Realm is a string to define realm attribute of BasicAuth. + // the realm identifies the system to authenticate against + // and can be used by clients to save credentials + // + // Optional. Default: "Restricted". + Realm string + + // Authorizer defines a function you can pass + // to check the credentials however you want. + // It will be called with a username and password + // and is expected to return true or false to indicate + // that the credentials were approved or not. + // + // Optional. Default: nil. + Authorizer func(string, string) bool + + // Unauthorized defines the response body for unauthorized responses. + // By default it will return with a 401 Unauthorized and the correct WWW-Auth header + // + // Optional. Default: nil + Unauthorized fiber.Handler + + // ContextUser is the key to store the username in Locals + // + // Optional. Default: "username" + ContextUsername string + + // ContextPass is the key to store the password in Locals + // + // Optional. Default: "password" + ContextPassword string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Users: map[string]string{}, + Realm: "Restricted", + Authorizer: nil, + Unauthorized: nil, + ContextUsername: "username", + ContextPassword: "password", +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Users == nil { + cfg.Users = ConfigDefault.Users + } + if cfg.Realm == "" { + cfg.Realm = ConfigDefault.Realm + } + if cfg.Authorizer == nil { + cfg.Authorizer = func(user, pass string) bool { + user, exist := cfg.Users[user] + if !exist { + return false + } + return user == pass + } + } + if cfg.Unauthorized == nil { + cfg.Unauthorized = func(c *fiber.Ctx) error { + c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm) + return c.SendStatus(fiber.StatusUnauthorized) + } + } + if cfg.ContextUsername == "" { + cfg.ContextUsername = ConfigDefault.ContextUsername + } + if cfg.ContextPassword == "" { + cfg.ContextPassword = ConfigDefault.ContextPassword + } + return cfg +} diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index f374a67b..8300498a 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -12,74 +12,10 @@ import ( "github.com/gofiber/fiber/v2/utils" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Expiration is the time that an cached response will live - // - // Optional. Default: 1 * time.Minute - Expiration time.Duration - - // CacheControl enables client side caching if set to true - // - // Optional. Default: false - CacheControl bool - - // Key allows you to generate custom keys, by default c.Path() is used - // - // Default: func(c *fiber.Ctx) string { - // return c.Path() - // } - Key func(*fiber.Ctx) string - - // Store is used to store the state of the middleware - // - // Default: an in memory store for this process only - Store fiber.Storage - - // Internally used - if true, the simpler method of two maps is used in order to keep - // execution time down. - defaultStore bool -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Expiration: 1 * time.Minute, - CacheControl: false, - Key: func(c *fiber.Ctx) string { - return c.Path() - }, - defaultStore: true, -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if int(cfg.Expiration.Seconds()) == 0 { - cfg.Expiration = ConfigDefault.Expiration - } - if cfg.Key == nil { - cfg.Key = ConfigDefault.Key - } - if cfg.Store == nil { - cfg.defaultStore = true - } - } + cfg := configDefault(config...) var ( // Cache settings @@ -152,7 +88,7 @@ func New(config ...Config) fiber.Handler { } else { // Load data from store - storeEntry, err := cfg.Store.Get(key) + storeEntry, err := cfg.Storage.Get(key) if err != nil { return err } @@ -165,7 +101,7 @@ func New(config ...Config) fiber.Handler { } } - if entryBody, err = cfg.Store.Get(key + "_body"); err != nil { + if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil { return err } } @@ -183,10 +119,10 @@ func New(config ...Config) fiber.Handler { if cfg.defaultStore { delete(entries, key) } else { // Use custom storage - if err := cfg.Store.Delete(key); err != nil { + if err := cfg.Storage.Delete(key); err != nil { return err } - if err := cfg.Store.Delete(key + "_body"); err != nil { + if err := cfg.Storage.Delete(key + "_body"); err != nil { return err } } @@ -234,12 +170,12 @@ func New(config ...Config) fiber.Handler { } // Pass bytes to Storage - if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil { + if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil { return err } // Pass bytes to Storage - if err = cfg.Store.Set(key+"_body", entryBody, cfg.Expiration); err != nil { + if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil { return err } } diff --git a/middleware/cache/config.go b/middleware/cache/config.go new file mode 100644 index 00000000..10a3017b --- /dev/null +++ b/middleware/cache/config.go @@ -0,0 +1,87 @@ +package cache + +import ( + "fmt" + "time" + + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Expiration is the time that an cached response will live + // + // Optional. Default: 1 * time.Minute + Expiration time.Duration + + // CacheControl enables client side caching if set to true + // + // Optional. Default: false + CacheControl bool + + // Key allows you to generate custom keys, by default c.Path() is used + // + // Default: func(c *fiber.Ctx) string { + // return c.Path() + // } + Key func(*fiber.Ctx) string + + // Deprecated, use Storage instead + Store fiber.Storage + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Storage fiber.Storage + + // Internally used - if true, the simpler method of two maps is used in order to keep + // execution time down. + defaultStore bool +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Expiration: 1 * time.Minute, + CacheControl: false, + Key: func(c *fiber.Ctx) string { + return c.Path() + }, + defaultStore: true, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if int(cfg.Expiration.Seconds()) == 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if cfg.Key == nil { + cfg.Key = ConfigDefault.Key + } + if cfg.Storage == nil && cfg.Store == nil { + cfg.defaultStore = true + } + if cfg.Store != nil { + fmt.Println("cache: `Store` is deprecated, use `Storage` instead") + cfg.Storage = cfg.Store + cfg.defaultStore = true + } + return cfg +} diff --git a/middleware/compress/compress.go b/middleware/compress/compress.go index 80919ab0..e65d7855 100644 --- a/middleware/compress/compress.go +++ b/middleware/compress/compress.go @@ -5,54 +5,10 @@ import ( "github.com/valyala/fasthttp" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Level determines the compression algorithm - // - // Optional. Default: LevelDefault - // LevelDisabled: -1 - // LevelDefault: 0 - // LevelBestSpeed: 1 - // LevelBestCompression: 2 - Level Level -} - -// Level is numeric representation of compression level -type Level int - -// Represents compression level that will be used in the middleware -const ( - LevelDisabled Level = -1 - LevelDefault Level = 0 - LevelBestSpeed Level = 1 - LevelBestCompression Level = 2 -) - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Level: LevelDefault, -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression { - cfg.Level = ConfigDefault.Level - } - } + cfg := configDefault(config...) // Setup request handlers var ( diff --git a/middleware/compress/config.go b/middleware/compress/config.go new file mode 100644 index 00000000..5495ad4c --- /dev/null +++ b/middleware/compress/config.go @@ -0,0 +1,56 @@ +package compress + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Level determines the compression algorithm + // + // Optional. Default: LevelDefault + // LevelDisabled: -1 + // LevelDefault: 0 + // LevelBestSpeed: 1 + // LevelBestCompression: 2 + Level Level +} + +// Level is numeric representation of compression level +type Level int + +// Represents compression level that will be used in the middleware +const ( + LevelDisabled Level = -1 + LevelDefault Level = 0 + LevelBestSpeed Level = 1 + LevelBestCompression Level = 2 +) + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Level: LevelDefault, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression { + cfg.Level = ConfigDefault.Level + } + return cfg +} diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index c99297a9..3ab10612 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -40,9 +40,14 @@ type Config struct { // Expiration is the duration before csrf token will expire // - // Optional. Default: 24 * time.Hour + // Optional. Default: 1 * time.Hour Expiration time.Duration + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Storage fiber.Storage + // Context key to store generated CSRF token into context. // // Optional. Default value "csrf". @@ -58,8 +63,8 @@ var ConfigDefault = Config{ Name: "_csrf", SameSite: "Strict", }, - Expiration: 24 * time.Hour, - CookieExpires: 24 * time.Hour, // deprecated + Expiration: 1 * time.Hour, + CookieExpires: 1 * time.Hour, // deprecated } type storage struct { diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go new file mode 100644 index 00000000..c5c0021c --- /dev/null +++ b/middleware/limiter/config.go @@ -0,0 +1,103 @@ +package limiter + +import ( + "fmt" + "time" + + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Max number of recent connections during `Duration` seconds before sending a 429 response + // + // Default: 5 + Max int + + // Key allows you to generate custom keys, by default c.IP() is used + // + // Default: func(c *fiber.Ctx) string { + // return c.IP() + // } + Key func(*fiber.Ctx) string + + // Expiration is the time on how long to keep records of requests in memory + // + // Default: 1 * time.Minute + Expiration time.Duration + + // LimitReached is called when a request hits the limit + // + // Default: func(c *fiber.Ctx) error { + // return c.SendStatus(fiber.StatusTooManyRequests) + // } + LimitReached fiber.Handler + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Storage fiber.Storage + + // DEPRECATED: Use Expiration instead + Duration time.Duration + + // DEPRECATED, use Storage instead + Store fiber.Storage +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Max: 5, + Expiration: 1 * time.Minute, + Key: func(c *fiber.Ctx) string { + return c.IP() + }, + LimitReached: func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTooManyRequests) + }, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Max <= 0 { + cfg.Max = ConfigDefault.Max + } + if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if int(cfg.Duration.Seconds()) > 0 { + fmt.Println("[LIMITER] Duration is deprecated, please use Expiration") + if cfg.Expiration != ConfigDefault.Expiration { + cfg.Expiration = cfg.Duration + } + } + if cfg.Key == nil { + cfg.Key = ConfigDefault.Key + } + if cfg.LimitReached == nil { + cfg.LimitReached = ConfigDefault.LimitReached + } + if cfg.Store != nil { + fmt.Println("[LIMITER] Store is deprecated, please use Storage") + cfg.Storage = cfg.Store + } + return cfg +} diff --git a/middleware/limiter/limiter.go b/middleware/limiter/limiter.go index 42ab9801..bc6bfa77 100644 --- a/middleware/limiter/limiter.go +++ b/middleware/limiter/limiter.go @@ -10,69 +10,11 @@ import ( "github.com/gofiber/fiber/v2" ) -//go:generate msgp -unexported -//msgp:ignore Config - -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Max number of recent connections during `Duration` seconds before sending a 429 response - // - // Default: 5 - Max int - - // DEPRECATED: Use Expiration instead - Duration time.Duration - - // Expiration is the time on how long to keep records of requests in memory - // - // Default: 1 * time.Minute - Expiration time.Duration - - // Key allows you to generate custom keys, by default c.IP() is used - // - // Default: func(c *fiber.Ctx) string { - // return c.IP() - // } - Key func(*fiber.Ctx) string - - // LimitReached is called when a request hits the limit - // - // Default: func(c *fiber.Ctx) error { - // return c.SendStatus(fiber.StatusTooManyRequests) - // } - LimitReached fiber.Handler - - // Store is used to store the state of the middleware - // - // Default: an in memory store for this process only - Store fiber.Storage - - // Internally used - if true, the simpler method of two maps is used in order to keep - // execution time down. - defaultStore bool -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Max: 5, - Expiration: 1 * time.Minute, - Key: func(c *fiber.Ctx) string { - return c.IP() - }, - LimitReached: func(c *fiber.Ctx) error { - return c.SendStatus(fiber.StatusTooManyRequests) - }, - defaultStore: true, -} - -// X-RateLimit-* headers const ( + // Storage ErrNotExist + errNotExist = "key does not exist" + + // X-RateLimit-* headers xRateLimitLimit = "X-RateLimit-Limit" xRateLimitRemaining = "X-RateLimit-Remaining" xRateLimitReset = "X-RateLimit-Reset" @@ -81,38 +23,7 @@ const ( // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if cfg.Max <= 0 { - cfg.Max = ConfigDefault.Max - } - if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration - } - if int(cfg.Duration.Seconds()) > 0 { - fmt.Println("[LIMITER] Duration is deprecated, please use Expiration") - if cfg.Expiration != ConfigDefault.Expiration { - cfg.Expiration = cfg.Duration - } - } - if cfg.Key == nil { - cfg.Key = ConfigDefault.Key - } - if cfg.LimitReached == nil { - cfg.LimitReached = ConfigDefault.LimitReached - } - if cfg.Store == nil { - cfg.defaultStore = true - } - } + cfg := configDefault(config...) var ( // Limiter settings @@ -150,21 +61,19 @@ func New(config ...Config) fiber.Handler { mux.Lock() defer mux.Unlock() - // Use default memory storage - if cfg.defaultStore { - entry = entries[key] - } else { // Use custom storage - storeEntry, err := cfg.Store.Get(key) - if err != nil { - return err - } - // Only decode if we found an entry - if storeEntry != nil { - // Decode bytes using msgp - if _, err := entry.UnmarshalMsg(storeEntry); err != nil { + // Use Storage if provided + if cfg.Storage != nil { + val, err := cfg.Storage.Get(key) + if val != nil && len(val) > 0 { + if _, err := entry.UnmarshalMsg(val); err != nil { return err } } + if err != nil && err.Error() != errNotExist { + fmt.Println("[LIMITER]", err.Error()) + } + } else { + entry = entries[key] } // Get timestamp @@ -183,19 +92,20 @@ func New(config ...Config) fiber.Handler { // Increment hits entry.hits++ - // Use default memory storage - if cfg.defaultStore { - entries[key] = entry - } else { // Use custom storage - data, err := entry.MarshalMsg(nil) + // Use Storage if provided + if cfg.Storage != nil { + // Marshal entry to bytes + val, err := entry.MarshalMsg(nil) if err != nil { return err } - // Pass bytes to Storage - if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil { + // Pass value to Storage + if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil { return err } + } else { + entries[key] = entry } // Calculate when it resets in seconds @@ -223,10 +133,3 @@ func New(config ...Config) fiber.Handler { return c.Next() } } - -// replacer for strconv.FormatUint -// func appendInt(buf *bytebufferpool.ByteBuffer, v int) (int, error) { -// old := len(buf.B) -// buf.B = fasthttp.AppendUint(buf.B, v) -// return len(buf.B) - old, nil -// } diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index f2daf9b2..1ffae449 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/gofiber/fiber/v2/utils" - "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/internal/storage/memory" + "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" ) @@ -24,7 +24,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) { app.Use(New(Config{ Max: 50, Expiration: 2 * time.Second, - Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)}, + Storage: memory.New(), })) app.Get("/", func(c *fiber.Ctx) error { @@ -108,32 +108,6 @@ func Test_Limiter_Concurrency(t *testing.T) { } -// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4 -func Benchmark_Limiter(b *testing.B) { - app := fiber.New() - - app.Use(New(Config{ - Max: 100, - Expiration: 60 * time.Second, - })) - - app.Get("/", func(c *fiber.Ctx) error { - return c.SendString("Hello, World!") - }) - - h := app.Handler() - - fctx := &fasthttp.RequestCtx{} - fctx.Request.Header.SetMethod("GET") - fctx.Request.SetRequestURI("/") - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - h(fctx) - } -} - // go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4 func Benchmark_Limiter_Custom_Store(b *testing.B) { app := fiber.New() @@ -141,7 +115,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) { app.Use(New(Config{ Max: 100, Expiration: 60 * time.Second, - Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)}, + Storage: memory.New(), })) app.Get("/", func(c *fiber.Ctx) error { @@ -203,39 +177,28 @@ func Test_Limiter_Headers(t *testing.T) { } } -// testStore is used for testing custom stores -type testStore struct { - stmap map[string][]byte - mutex *sync.Mutex -} +// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4 +func Benchmark_Limiter(b *testing.B) { + app := fiber.New() -func (s testStore) Get(id string) ([]byte, error) { - s.mutex.Lock() - val, ok := s.stmap[id] - s.mutex.Unlock() - if !ok { - return nil, nil - } else { - return val, nil + app.Use(New(Config{ + Max: 100, + Expiration: 60 * time.Second, + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + h := app.Handler() + + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod("GET") + fctx.Request.SetRequestURI("/") + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(fctx) } } - -func (s testStore) Set(id string, val []byte, _ time.Duration) error { - s.mutex.Lock() - s.stmap[id] = val - s.mutex.Unlock() - - return nil -} - -func (s testStore) Reset() error { - return nil -} - -func (s testStore) Delete(id string) error { - return nil -} - -func (s testStore) Close() error { - return nil -} diff --git a/middleware/logger/logger.go b/middleware/logger/logger.go index d16b631e..be9e11a9 100644 --- a/middleware/logger/logger.go +++ b/middleware/logger/logger.go @@ -186,6 +186,7 @@ func New(config ...Config) fiber.Handler { var ( start, stop time.Time once sync.Once + mu sync.Mutex errHandler fiber.ErrorHandler ) @@ -362,6 +363,7 @@ func New(config ...Config) fiber.Handler { if err != nil { _, _ = buf.WriteString(err.Error()) } + mu.Lock() // Write buffer to output if _, err := cfg.Output.Write(buf.Bytes()); err != nil { // Write error to output @@ -370,6 +372,7 @@ func New(config ...Config) fiber.Handler { // TODO: What should we do here? } } + mu.Unlock() // Put buffer back to pool bytebufferpool.Put(buf)