mirror of https://github.com/gofiber/fiber.git
fix(middleware/cors): Categorize requests correctly (#2921)
* fix(middleware/cors): categorise requests correctly * test(middleware/cors): improve test coverage for request types * test(middleware/cors): Add subdomain matching tests * test(middleware/cors): parallel tests for CORS headers based on request type * test(middleware/cors): Add benchmark for CORS subdomain matching * test(middleware/cors): cover additiona test cases * refactor(middleware/cors): origin validation and normalizationpull/2932/head
parent
1aac6f618b
commit
1607d872d9
|
@ -119,33 +119,23 @@ func New(config ...Config) fiber.Handler {
|
||||||
allowSOrigins := []subdomain{}
|
allowSOrigins := []subdomain{}
|
||||||
allowAllOrigins := false
|
allowAllOrigins := false
|
||||||
|
|
||||||
// processOrigin processes an origin string, normalizes it and checks its validity
|
|
||||||
// it will panic if the origin is invalid
|
|
||||||
processOrigin := func(origin string) (string, bool) {
|
|
||||||
trimmedOrigin := strings.TrimSpace(origin)
|
|
||||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
|
||||||
if !isValid {
|
|
||||||
log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin)
|
|
||||||
panic("[CORS] Invalid origin provided in configuration")
|
|
||||||
}
|
|
||||||
return normalizedOrigin, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate and normalize static AllowOrigins
|
// Validate and normalize static AllowOrigins
|
||||||
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
|
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
|
||||||
origins := strings.Split(cfg.AllowOrigins, ",")
|
origins := strings.Split(cfg.AllowOrigins, ",")
|
||||||
for _, origin := range origins {
|
for _, origin := range origins {
|
||||||
if i := strings.Index(origin, "://*."); i != -1 {
|
if i := strings.Index(origin, "://*."); i != -1 {
|
||||||
normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:])
|
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
|
||||||
|
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||||
if !isValid {
|
if !isValid {
|
||||||
continue
|
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||||
}
|
}
|
||||||
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
||||||
allowSOrigins = append(allowSOrigins, sd)
|
allowSOrigins = append(allowSOrigins, sd)
|
||||||
} else {
|
} else {
|
||||||
normalizedOrigin, isValid := processOrigin(origin)
|
trimmedOrigin := strings.TrimSpace(origin)
|
||||||
|
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||||
if !isValid {
|
if !isValid {
|
||||||
continue
|
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||||
}
|
}
|
||||||
allowOrigins = append(allowOrigins, normalizedOrigin)
|
allowOrigins = append(allowOrigins, normalizedOrigin)
|
||||||
}
|
}
|
||||||
|
@ -172,8 +162,9 @@ func New(config ...Config) fiber.Handler {
|
||||||
// Get originHeader header
|
// Get originHeader header
|
||||||
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
||||||
|
|
||||||
// If the request does not have an Origin header, the request is outside the scope of CORS
|
// If the request does not have Origin and Access-Control-Request-Method
|
||||||
if originHeader == "" {
|
// headers, the request is outside the scope of CORS
|
||||||
|
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,8 +202,9 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simple request
|
// Simple request
|
||||||
|
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
|
||||||
if c.Method() != fiber.MethodOptions {
|
if c.Method() != fiber.MethodOptions {
|
||||||
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
|
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,14 +225,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos
|
||||||
|
|
||||||
if cfg.AllowCredentials {
|
if cfg.AllowCredentials {
|
||||||
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
||||||
if allowOrigin != "*" && allowOrigin != "" {
|
if allowOrigin == "*" {
|
||||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
||||||
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
|
||||||
} else if allowOrigin == "*" {
|
|
||||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
|
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
|
||||||
|
} else if allowOrigin != "" {
|
||||||
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
|
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
||||||
}
|
}
|
||||||
} else if len(allowOrigin) > 0 {
|
} else if allowOrigin != "" {
|
||||||
// For non-credential requests, it's safe to set to '*' or specific origins
|
// For non-credential requests, it's safe to set to '*' or specific origins
|
||||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) {
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
app.Handler()(ctx)
|
app.Handler()(ctx)
|
||||||
|
|
||||||
|
@ -49,6 +50,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
|
||||||
// Test default GET response headers
|
// Test default GET response headers
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
|
||||||
|
@ -59,6 +61,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
|
||||||
// Test default OPTIONS (preflight) response headers
|
// Test default OPTIONS (preflight) response headers
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
h(ctx)
|
h(ctx)
|
||||||
|
|
||||||
|
@ -87,6 +90,7 @@ func Test_CORS_Wildcard(t *testing.T) {
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
|
|
||||||
// Perform request
|
// Perform request
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -101,6 +105,7 @@ func Test_CORS_Wildcard(t *testing.T) {
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
|
||||||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
|
||||||
|
@ -128,6 +133,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
|
|
||||||
// Perform request
|
// Perform request
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -141,6 +147,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
|
||||||
// Test non OPTIONS (preflight) response headers
|
// Test non OPTIONS (preflight) response headers
|
||||||
ctx = &fasthttp.RequestCtx{}
|
ctx = &fasthttp.RequestCtx{}
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
|
||||||
|
@ -183,7 +190,9 @@ func Test_CORS_Invalid_Origins_Panic(t *testing.T) {
|
||||||
"http://foo.[a-z]*.example.com",
|
"http://foo.[a-z]*.example.com",
|
||||||
"http://*",
|
"http://*",
|
||||||
"https://*",
|
"https://*",
|
||||||
|
"http://*.com*",
|
||||||
"invalid url",
|
"invalid url",
|
||||||
|
"http://origin.com,invalid url",
|
||||||
// add more invalid origins as needed
|
// add more invalid origins as needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,6 +235,7 @@ func Test_CORS_Subdomain(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
|
||||||
|
|
||||||
// Perform request
|
// Perform request
|
||||||
|
@ -240,6 +250,7 @@ func Test_CORS_Subdomain(t *testing.T) {
|
||||||
// Make request with domain only (disallowed)
|
// Make request with domain only (disallowed)
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -252,6 +263,7 @@ func Test_CORS_Subdomain(t *testing.T) {
|
||||||
// Make request with allowed origin
|
// Make request with allowed origin
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -270,6 +282,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
|
||||||
reqOrigin: "http://example.com",
|
reqOrigin: "http://example.com",
|
||||||
shouldAllowOrigin: true,
|
shouldAllowOrigin: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
pattern: "HTTP://EXAMPLE.COM",
|
||||||
|
reqOrigin: "http://example.com",
|
||||||
|
shouldAllowOrigin: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
pattern: "https://example.com",
|
pattern: "https://example.com",
|
||||||
reqOrigin: "https://example.com",
|
reqOrigin: "https://example.com",
|
||||||
|
@ -300,6 +317,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
|
||||||
reqOrigin: "http://aaa.example.com:8080",
|
reqOrigin: "http://aaa.example.com:8080",
|
||||||
shouldAllowOrigin: true,
|
shouldAllowOrigin: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
pattern: "http://*.example.com",
|
||||||
|
reqOrigin: "http://1.2.aaa.example.com",
|
||||||
|
shouldAllowOrigin: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
pattern: "http://example.com",
|
pattern: "http://example.com",
|
||||||
reqOrigin: "http://gofiber.com",
|
reqOrigin: "http://gofiber.com",
|
||||||
|
@ -366,6 +388,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
|
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -422,6 +445,103 @@ func Test_CORS_Next(t *testing.T) {
|
||||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// go test -run Test_CORS_Headers_BasedOnRequestType
|
||||||
|
func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
app := fiber.New()
|
||||||
|
app.Use(New(Config{}))
|
||||||
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
|
return c.SendStatus(fiber.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
methods := []string{
|
||||||
|
fiber.MethodGet,
|
||||||
|
fiber.MethodPost,
|
||||||
|
fiber.MethodPut,
|
||||||
|
fiber.MethodDelete,
|
||||||
|
fiber.MethodPatch,
|
||||||
|
fiber.MethodHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get handler pointer
|
||||||
|
handler := app.Handler()
|
||||||
|
|
||||||
|
t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Make request without origin header, and without Access-Control-Request-Method
|
||||||
|
for _, method := range methods {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetMethod(method)
|
||||||
|
ctx.Request.SetRequestURI("https://example.com/")
|
||||||
|
handler(ctx)
|
||||||
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Make request with origin header, but without Access-Control-Request-Method
|
||||||
|
for _, method := range methods {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetMethod(method)
|
||||||
|
ctx.Request.SetRequestURI("https://example.com/")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||||
|
handler(ctx)
|
||||||
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Make request without origin header, but with Access-Control-Request-Method
|
||||||
|
for _, method := range methods {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetMethod(method)
|
||||||
|
ctx.Request.SetRequestURI("https://example.com/")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
|
handler(ctx)
|
||||||
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Make preflight request with origin header and with Access-Control-Request-Method
|
||||||
|
for _, method := range methods {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.SetRequestURI("https://example.com/")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
|
||||||
|
handler(ctx)
|
||||||
|
utils.AssertEqual(t, 204, ctx.Response.StatusCode(), "Status code should be 204")
|
||||||
|
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
|
||||||
|
utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Make non-preflight request with origin header and with Access-Control-Request-Method
|
||||||
|
for _, method := range methods {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetMethod(method)
|
||||||
|
ctx.Request.SetRequestURI("https://example.com/api/action")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
|
||||||
|
handler(ctx)
|
||||||
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||||
|
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)")
|
||||||
|
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
|
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
// New fiber instance
|
// New fiber instance
|
||||||
|
@ -440,6 +560,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
|
||||||
|
|
||||||
// Perform request
|
// Perform request
|
||||||
|
@ -454,6 +575,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
|
||||||
// Make request with allowed origin
|
// Make request with allowed origin
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -466,6 +588,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
|
||||||
// Make request with allowed origin
|
// Make request with allowed origin
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -505,6 +628,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
|
||||||
// Make request with allowed origin
|
// Make request with allowed origin
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
|
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -569,7 +693,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
|
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "http://aaa.com",
|
AllowOrigins: "http://aaa.com",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -580,7 +704,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
|
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "http://aaa.com",
|
AllowOrigins: "http://aaa.com",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -591,7 +715,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
|
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "http://aaa.com",
|
AllowOrigins: "http://aaa.com",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return false
|
return false
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -602,7 +726,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
|
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "http://aaa.com",
|
AllowOrigins: "http://aaa.com",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return false
|
return false
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -622,7 +746,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
|
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "",
|
AllowOrigins: "",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -633,7 +757,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
|
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowOrigins: "",
|
AllowOrigins: "",
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return false
|
return false
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -652,6 +776,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
|
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
@ -674,7 +799,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
|
||||||
Name: "AllowOriginsFuncDefined",
|
Name: "AllowOriginsFuncDefined",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -687,7 +812,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
|
||||||
Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials",
|
Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials",
|
||||||
Config: Config{
|
Config: Config{
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
AllowOriginsFunc: func(origin string) bool {
|
AllowOriginsFunc: func(_ string) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -742,6 +867,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Request.SetRequestURI("/")
|
ctx.Request.SetRequestURI("/")
|
||||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||||
|
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||||
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
|
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
|
||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
|
|
@ -2,6 +2,8 @@ package cors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// go test -run -v Test_normalizeOrigin
|
// go test -run -v Test_normalizeOrigin
|
||||||
|
@ -16,6 +18,9 @@ func Test_normalizeOrigin(t *testing.T) {
|
||||||
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
|
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
|
||||||
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
|
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
|
||||||
{"http://", false, ""}, // Invalid origin should not be accepted.
|
{"http://", false, ""}, // Invalid origin should not be accepted.
|
||||||
|
{"file:///etc/passwd", false, ""}, // File scheme should not be accepted.
|
||||||
|
{"https://*example.com", false, ""}, // Wildcard domain should not be accepted.
|
||||||
|
{"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted.
|
||||||
{"http://example.com/path", false, ""}, // Path should not be accepted.
|
{"http://example.com/path", false, ""}, // Path should not be accepted.
|
||||||
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
|
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
|
||||||
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
|
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
|
||||||
|
@ -105,3 +110,86 @@ func Test_normalizeDomain(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
|
||||||
|
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
|
||||||
|
s := subdomain{
|
||||||
|
prefix: "www",
|
||||||
|
suffix: ".example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
o := "www.example.com"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
s.match(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CORS_SubdomainMatch(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sub subdomain
|
||||||
|
origin string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match with different scheme",
|
||||||
|
sub: subdomain{prefix: "http://api.", suffix: ".example.com"},
|
||||||
|
origin: "https://api.service.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match with different scheme",
|
||||||
|
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||||
|
origin: "http://api.service.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match with valid subdomain",
|
||||||
|
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||||
|
origin: "https://api.service.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match with valid nested subdomain",
|
||||||
|
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||||
|
origin: "https://1.2.api.service.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "no match with invalid prefix",
|
||||||
|
sub: subdomain{prefix: "https://abc.", suffix: ".example.com"},
|
||||||
|
origin: "https://service.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match with invalid suffix",
|
||||||
|
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||||
|
origin: "https://api.example.org",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match with empty origin",
|
||||||
|
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||||
|
origin: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial match not considered a match",
|
||||||
|
sub: subdomain{prefix: "https://service.", suffix: ".example.com"},
|
||||||
|
origin: "https://api.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.sub.match(tt.origin)
|
||||||
|
utils.AssertEqual(t, tt.expected, got, "subdomain.match()")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue