mirror of https://github.com/gofiber/fiber.git
fix(middleware/cors): CORS handling (#2937)
* fix(middleware/cors): CORS handling * fix(middleware/cors): Vary header handling * test(middleware/cors): Ensure Vary Headers checkedpull/2939/head
parent
43d5091967
commit
e574c0db52
|
@ -162,9 +162,19 @@ func New(config ...Config) fiber.Handler {
|
|||
// Get originHeader header
|
||||
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
||||
|
||||
// If the request does not have Origin and Access-Control-Request-Method
|
||||
// headers, the request is outside the scope of CORS
|
||||
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||
// If the request does not have Origin header, the request is outside the scope of CORS
|
||||
if originHeader == "" {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
// Unless all origins are allowed, we include the Vary header to cache the response correctly
|
||||
if !allowAllOrigins {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
|
||||
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
|
@ -204,13 +214,23 @@ func New(config ...Config) fiber.Handler {
|
|||
// Simple request
|
||||
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
if !allowAllOrigins {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
// Pre-flight request
|
||||
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// of preflight responses:
|
||||
c.Vary(fiber.HeaderAccessControlRequestMethod)
|
||||
c.Vary(fiber.HeaderAccessControlRequestHeaders)
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
|
||||
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
|
||||
|
||||
|
@ -221,8 +241,6 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// Function to set CORS headers
|
||||
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
|
||||
if cfg.AllowCredentials {
|
||||
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
||||
if allowOrigin == "*" {
|
||||
|
|
|
@ -50,7 +50,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
|
|||
// Test default GET response headers
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
h(ctx)
|
||||
|
||||
|
@ -70,6 +69,44 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
|
|||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
|
||||
}
|
||||
|
||||
func Test_CORS_AllowOrigins_Vary(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(
|
||||
Config{
|
||||
AllowOrigins: "http://localhost",
|
||||
},
|
||||
))
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
// Test Vary header non-Cors request
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin")
|
||||
|
||||
// Test Vary header Cors preflight request
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod(fiber.MethodOptions)
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
h(ctx)
|
||||
vh := string(ctx.Response.Header.Peek(fiber.HeaderVary))
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin")
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method")
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers")
|
||||
|
||||
// Test Vary header Cors request
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin")
|
||||
}
|
||||
|
||||
// go test -run -v Test_CORS_Wildcard
|
||||
func Test_CORS_Wildcard(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
@ -97,6 +134,10 @@ func Test_CORS_Wildcard(t *testing.T) {
|
|||
|
||||
// Check result
|
||||
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
|
||||
vh := string(ctx.Response.Header.Peek(fiber.HeaderVary))
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin")
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method")
|
||||
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers")
|
||||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
|
||||
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
|
||||
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
|
||||
|
@ -105,9 +146,9 @@ func Test_CORS_Wildcard(t *testing.T) {
|
|||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
handler(ctx)
|
||||
|
||||
utils.AssertEqual(t, false, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should not be set for Origin")
|
||||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
|
||||
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
|
||||
}
|
||||
|
@ -147,7 +188,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
|
|||
// Test non OPTIONS (preflight) response headers
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
handler(ctx)
|
||||
|
||||
|
@ -466,7 +506,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
|
|||
// Get handler pointer
|
||||
handler := app.Handler()
|
||||
|
||||
t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||
t.Run("Without origin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Make request without origin header, and without Access-Control-Request-Method
|
||||
for _, method := range methods {
|
||||
|
@ -479,34 +519,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Make request with origin header, but without Access-Control-Request-Method
|
||||
for _, method := range methods {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(method)
|
||||
ctx.Request.SetRequestURI("https://example.com/")
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
handler(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Make request without origin header, but with Access-Control-Request-Method
|
||||
for _, method := range methods {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(method)
|
||||
ctx.Request.SetRequestURI("https://example.com/")
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
handler(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Make preflight request with origin header and with Access-Control-Request-Method
|
||||
|
@ -524,7 +536,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
|
||||
t.Run("Non-preflight request with origin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Make non-preflight request with origin header and with Access-Control-Request-Method
|
||||
for _, method := range methods {
|
||||
|
@ -532,7 +544,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
|
|||
ctx.Request.Header.SetMethod(method)
|
||||
ctx.Request.SetRequestURI("https://example.com/api/action")
|
||||
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
|
||||
handler(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
|
||||
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
|
||||
|
@ -901,7 +912,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -942,7 +952,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -976,7 +985,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -1017,7 +1025,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -1051,7 +1058,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -1092,7 +1098,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
|
|||
req.Header.SetMethod(fiber.MethodGet)
|
||||
req.SetRequestURI("/")
|
||||
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
|
||||
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
|
||||
|
||||
ctx.Init(req, nil, nil)
|
||||
|
@ -1122,6 +1127,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
|
|||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Preflight request
|
||||
req := &fasthttp.Request{}
|
||||
req.Header.SetMethod(fiber.MethodOptions)
|
||||
req.SetRequestURI("/")
|
||||
|
|
Loading…
Reference in New Issue