test(middleware/csrf): Fix Benchmark Tests (#2932)

* test(middleware/csrf): fix Benchmark_Middleware_CSRF_*

* fix(middleware/csrf): update refererMatchesHost()
pull/2937/head
Jason McNeil 2024-03-25 11:30:20 -03:00 committed by GitHub
parent 1607d872d9
commit ba10e68d01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 10 deletions

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -220,7 +221,7 @@ func isCsrfFromCookie(extractor interface{}) bool {
// returns an error if the referer header is not present or is invalid // returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid // returns nil if the referer header is valid
func refererMatchesHost(c *fiber.Ctx) error { func refererMatchesHost(c *fiber.Ctx) error {
referer := c.Get(fiber.HeaderReferer) referer := strings.ToLower(c.Get(fiber.HeaderReferer))
if referer == "" { if referer == "" {
return ErrNoReferer return ErrNoReferer
} }
@ -230,9 +231,9 @@ func refererMatchesHost(c *fiber.Ctx) error {
return ErrBadReferer return ErrBadReferer
} }
if refererURL.Scheme+"://"+refererURL.Host != c.Protocol()+"://"+c.Hostname() { if refererURL.Scheme == c.Protocol() && refererURL.Host == c.Hostname() {
return ErrBadReferer return nil
} }
return nil return ErrBadReferer
} }

View File

@ -992,7 +992,10 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
return c.SendStatus(fiber.StatusTeapot) return c.SendStatus(fiber.StatusTeapot)
}) })
fctx := &fasthttp.RequestCtx{} app.Post("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -1002,17 +1005,27 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1] token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Test Correct Referer POST
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
ctx.Request.URI().SetScheme("https")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
h(fctx) h(ctx)
} }
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
} }
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
@ -1024,7 +1037,6 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
return c.SendStatus(fiber.StatusTeapot) return c.SendStatus(fiber.StatusTeapot)
}) })
fctx := &fasthttp.RequestCtx{}
h := app.Handler() h := app.Handler()
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -1034,8 +1046,9 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
h(fctx) h(ctx)
} }
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) // Ensure the GET request returns a 418 status code
utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
} }