From c4d2876d64590bb6ddfbcc9c21253c927f35dbca Mon Sep 17 00:00:00 2001 From: James Lucas Date: Fri, 21 Apr 2023 12:37:53 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(cors):=20Changed=20condition?= =?UTF-8?q?=20for=20'AllowOriginsFunc'=20(#2423)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🐛 fix(cors): Changed condition for 'AllowOriginsFunc' to check against default config value of 'AllowOrigins' --- middleware/cors/cors.go | 4 ++-- middleware/cors/cors_test.go | 44 ++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index cf90aee2..181134ba 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -97,7 +97,7 @@ func New(config ...Config) fiber.Handler { } // Warning logs if both AllowOrigins and AllowOriginsFunc are set - if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil { + if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil { log.Printf("[CORS] - [Warning] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.\n") } @@ -142,7 +142,7 @@ func New(config ...Config) fiber.Handler { // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. - if allowOrigin == "" && cfg.AllowOriginsFunc != nil { + if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil { if cfg.AllowOriginsFunc(origin) { allowOrigin = origin } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 1f4a3c91..15b51a95 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -244,7 +244,7 @@ func Test_CORS_Next(t *testing.T) { utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) } -func Test_CORS_AllowOriginsFunc(t *testing.T) { +func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() @@ -267,7 +267,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Perform request handler(ctx) - // Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com + // Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")' utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() @@ -294,3 +294,43 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } + +func Test_CORS_AllowOriginsFunc(t *testing.T) { + t.Parallel() + // New fiber instance + app := fiber.New() + app.Use("/", New(Config{ + AllowOriginsFunc: func(origin string) bool { + return strings.Contains(origin, "example-2") + }, + })) + + // Get handler pointer + handler := app.Handler() + + // Make request with disallowed origin + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") + + // Perform request + handler(ctx) + + // Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")' + // and AllowOrigins has not been set so the default "*" is used + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + + ctx.Request.Reset() + ctx.Response.Reset() + + // Make request with allowed origin + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") + + handler(ctx) + + // Allow-Origin header should be "http://example-2.com" + utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) +}