🐛 fix(cors): Changed condition for 'AllowOriginsFunc' (#2423)

🐛 fix(cors): Changed condition for 'AllowOriginsFunc' to check against default config value of 'AllowOrigins'
This commit is contained in:
James Lucas 2023-04-21 12:37:53 +01:00 committed by GitHub
parent 3e9575b0fe
commit c4d2876d64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 4 deletions

View File

@ -97,7 +97,7 @@ func New(config ...Config) fiber.Handler {
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil {
log.Printf("[CORS] - [Warning] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.\n")
}
@ -142,7 +142,7 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}

View File

@ -244,7 +244,7 @@ func Test_CORS_Next(t *testing.T) {
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_CORS_AllowOriginsFunc(t *testing.T) {
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
@ -267,7 +267,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Perform request
handler(ctx)
// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
// Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")'
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
@ -294,3 +294,43 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func Test_CORS_AllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOriginsFunc: func(origin string) bool {
return strings.Contains(origin, "example-2")
},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set so the default "*" is used
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
// Allow-Origin header should be "http://example-2.com"
utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}