// Source: fiber/middleware/cors/cors_test.go (Go; listing metadata: 649 lines, 18 KiB)

package cors
import (
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func Test_CORS_Defaults(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
testDefaultOrEmptyConfig(t, app)
}
func Test_CORS_Empty_Config(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{}))
testDefaultOrEmptyConfig(t, app)
}
// Test_CORS_Negative_MaxAge verifies that configuring MaxAge as -1 results in
// an Access-Control-Max-Age header of "0" on a preflight response.
func Test_CORS_Negative_MaxAge(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{MaxAge: -1}))
	handler := app.Handler()

	preflight := &fasthttp.RequestCtx{}
	preflight.Request.Header.SetMethod(fiber.MethodOptions)
	preflight.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
	handler(preflight)

	maxAge := string(preflight.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))
	require.Equal(t, "0", maxAge)
}
// testDefaultOrEmptyConfig asserts the CORS headers app produces for a plain
// GET request and for an OPTIONS request when the middleware runs with its
// default settings. Neither request carries an Origin header.
func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
	t.Helper()

	handler := app.Handler()
	header := func(ctx *fasthttp.RequestCtx, name string) string {
		return string(ctx.Response.Header.Peek(name))
	}

	// Plain GET: wildcard origin, no credentials, nothing exposed.
	get := &fasthttp.RequestCtx{}
	get.Request.Header.SetMethod(fiber.MethodGet)
	handler(get)
	require.Equal(t, "*", header(get, fiber.HeaderAccessControlAllowOrigin))
	require.Equal(t, "", header(get, fiber.HeaderAccessControlAllowCredentials))
	require.Equal(t, "", header(get, fiber.HeaderAccessControlExposeHeaders))

	// OPTIONS: the default method list is advertised; no headers or max-age.
	options := &fasthttp.RequestCtx{}
	options.Request.Header.SetMethod(fiber.MethodOptions)
	handler(options)
	require.Equal(t, "GET,POST,HEAD,PUT,DELETE,PATCH", header(options, fiber.HeaderAccessControlAllowMethods))
	require.Equal(t, "", header(options, fiber.HeaderAccessControlAllowHeaders))
	require.Equal(t, "", header(options, fiber.HeaderAccessControlMaxAge))
}
// Test_CORS_Wildcard verifies the headers produced when AllowOrigins is "*":
// the preflight response carries the literal wildcard (not the echoed request
// Origin), never allows credentials, and reports MaxAge and AllowHeaders; a
// plain GET reports ExposeHeaders.
//
// Run with: go test -v -run Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		AllowOrigins:  "*",
		MaxAge:        3600,
		ExposeHeaders: "X-Request-ID",
		AllowHeaders:  "Authentication",
	}))
	handler := app.Handler()

	// Preflight (OPTIONS) request from a concrete origin.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
	ctx.Request.Header.SetMethod(fiber.MethodOptions)
	handler(ctx)

	// "*" rather than "http://localhost": the origin must not be reflected.
	require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
	require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
	require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
	require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))

	// Non-preflight (GET) request.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	handler(ctx)

	require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
	require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// Test_CORS_Origin_AllowCredentials verifies the headers produced when a
// single explicit origin is allowed together with AllowCredentials: the
// preflight response echoes that origin and sets credentials to "true", and a
// plain GET also reports credentials and ExposeHeaders.
//
// Run with: go test -v -run Test_CORS_Origin_AllowCredentials
func Test_CORS_Origin_AllowCredentials(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	// Explicit origin (not "*"), so credentials may be enabled.
	app.Use(New(Config{
		AllowOrigins:     "http://localhost",
		AllowCredentials: true,
		MaxAge:           3600,
		ExposeHeaders:    "X-Request-ID",
		AllowHeaders:     "Authentication",
	}))
	handler := app.Handler()

	// Preflight (OPTIONS) request from the allowed origin.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
	ctx.Request.Header.SetMethod(fiber.MethodOptions)
	handler(ctx)

	require.Equal(t, "http://localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
	require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
	require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
	require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))

	// Non-preflight (GET) request.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	handler(ctx)

	require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
	require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// Test_CORS_Wildcard_AllowCredentials_Panic is the regression test for
// advisory fiber-ghsa-fmg4-x8pw-hjhg: registering the middleware with
// AllowOrigins "*" and AllowCredentials true must panic.
//
// Run with: go test -v -run Test_CORS_Wildcard_AllowCredentials_Panic
func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	require.Panics(t, func() {
		app.Use(New(Config{
			AllowOrigins:     "*",
			AllowCredentials: true,
		}))
	}, "Expected a panic when AllowOrigins is '*' and AllowCredentials is true")
}
// Test_CORS_Invalid_Origin_Panic verifies that registering the middleware with
// an AllowOrigins entry that lacks a scheme ("localhost") panics.
//
// Run with: go test -v -run Test_CORS_Invalid_Origin_Panic
func Test_CORS_Invalid_Origin_Panic(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	require.Panics(t, func() {
		app.Use(New(Config{
			AllowOrigins:     "localhost",
			AllowCredentials: true,
		}))
	}, "Expected a panic when Origin is missing scheme")
}
// Test_CORS_Subdomain verifies subdomain wildcard matching: with AllowOrigins
// set to "http://*.example.com", an unrelated host gets no
// Access-Control-Allow-Origin header while a subdomain has its origin echoed.
//
// Run with: go test -v -run Test_CORS_Subdomain
func Test_CORS_Subdomain(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use("/", New(Config{AllowOrigins: "http://*.example.com"}))
	handler := app.Handler()

	// Disallowed origin: http://google.com does not satisfy http://*.example.com.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetMethod(fiber.MethodOptions)
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
	handler(ctx)
	require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))

	ctx.Request.Reset()
	ctx.Response.Reset()

	// Allowed origin: a subdomain of example.com is echoed back.
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetMethod(fiber.MethodOptions)
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
	handler(ctx)
	require.Equal(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
// Test_CORS_AllowOriginScheme checks, case by case, whether a single
// AllowOrigins pattern (exact origin or "*."-subdomain wildcard, optionally
// with a port) accepts a given request Origin, including the scheme. An
// accepted origin is echoed in Access-Control-Allow-Origin; a rejected one
// yields an empty header. Each case runs as its own named subtest so a failure
// identifies the offending pattern/origin pair.
func Test_CORS_AllowOriginScheme(t *testing.T) {
	t.Parallel()

	tests := []struct {
		reqOrigin, pattern string
		shouldAllowOrigin  bool
	}{
		{
			pattern:           "http://example.com",
			reqOrigin:         "http://example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "https://example.com",
			reqOrigin:         "https://example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://example.com",
			reqOrigin:         "https://example.com",
			shouldAllowOrigin: false,
		},
		{
			pattern:           "http://*.example.com",
			reqOrigin:         "http://aaa.example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://*.example.com",
			reqOrigin:         "http://bbb.aaa.example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://*.aaa.example.com",
			reqOrigin:         "http://bbb.aaa.example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://*.example.com:8080",
			reqOrigin:         "http://aaa.example.com:8080",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://example.com",
			reqOrigin:         "http://gofiber.com",
			shouldAllowOrigin: false,
		},
		{
			pattern:           "http://*.aaa.example.com",
			reqOrigin:         "http://ccc.bbb.example.com",
			shouldAllowOrigin: false,
		},
		{
			pattern:           "http://*.example.com",
			reqOrigin:         "http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://example.com",
			reqOrigin:         "http://ccc.bbb.example.com",
			shouldAllowOrigin: false,
		},
		{
			pattern:           "https://*--aaa.bbb.com",
			reqOrigin:         "https://prod-preview--aaa.bbb.com",
			shouldAllowOrigin: false,
		},
		{
			pattern:           "http://*.example.com",
			reqOrigin:         "http://ccc.bbb.example.com",
			shouldAllowOrigin: true,
		},
		{
			pattern:           "http://foo.[a-z]*.example.com",
			reqOrigin:         "http://ccc.bbb.example.com",
			shouldAllowOrigin: false,
		},
	}

	for _, tt := range tests {
		tt := tt // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tt.pattern+" "+tt.reqOrigin, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use("/", New(Config{AllowOrigins: tt.pattern}))
			handler := app.Handler()

			ctx := &fasthttp.RequestCtx{}
			ctx.Request.SetRequestURI("/")
			ctx.Request.Header.SetMethod(fiber.MethodOptions)
			ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
			handler(ctx)

			want := ""
			if tt.shouldAllowOrigin {
				want = tt.reqOrigin
			}
			require.Equal(t, want, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
		})
	}
}
// Test_CORS_Next verifies that when Next returns true the middleware is
// skipped, so a GET to an unregistered route ends with 404 Not Found.
//
// Run with: go test -run Test_CORS_Next
func Test_CORS_Next(t *testing.T) {
	t.Parallel()

	skipAlways := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(New(Config{Next: skipAlways}))

	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_CORS_AllowOriginsAndAllowOriginsFunc verifies that when both a static
// AllowOrigins list and an AllowOriginsFunc predicate are configured, an
// origin accepted by either one is echoed back and any other origin gets an
// empty Access-Control-Allow-Origin header.
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use("/", New(Config{
		AllowOrigins: "http://example-1.com",
		AllowOriginsFunc: func(origin string) bool {
			return strings.Contains(origin, "example-2")
		},
	}))
	handler := app.Handler()

	// preflight issues an OPTIONS request from origin and returns the
	// resulting Access-Control-Allow-Origin header value.
	preflight := func(origin string) string {
		ctx := &fasthttp.RequestCtx{}
		ctx.Request.SetRequestURI("/")
		ctx.Request.Header.SetMethod(fiber.MethodOptions)
		ctx.Request.Header.Set(fiber.HeaderOrigin, origin)
		handler(ctx)
		return string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))
	}

	// Neither listed nor accepted by the predicate.
	require.Equal(t, "", preflight("http://google.com"))
	// Listed in AllowOrigins.
	require.Equal(t, "http://example-1.com", preflight("http://example-1.com"))
	// Accepted only by AllowOriginsFunc.
	require.Equal(t, "http://example-2.com", preflight("http://example-2.com"))
}
// Test_CORS_AllowOriginsFunc verifies that when only AllowOriginsFunc is
// configured (AllowOrigins left unset), origins the predicate accepts are
// echoed back and all others get an empty Access-Control-Allow-Origin header.
func Test_CORS_AllowOriginsFunc(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use("/", New(Config{
		AllowOriginsFunc: func(origin string) bool {
			return strings.Contains(origin, "example-2")
		},
	}))
	handler := app.Handler()

	// preflight issues an OPTIONS request from origin and returns the
	// resulting Access-Control-Allow-Origin header value.
	preflight := func(origin string) string {
		ctx := &fasthttp.RequestCtx{}
		ctx.Request.SetRequestURI("/")
		ctx.Request.Header.SetMethod(fiber.MethodOptions)
		ctx.Request.Header.Set(fiber.HeaderOrigin, origin)
		handler(ctx)
		return string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))
	}

	// Rejected by the predicate, and there is no static list to fall back on.
	require.Equal(t, "", preflight("http://google.com"))
	// Accepted by the predicate, so the origin is echoed.
	require.Equal(t, "http://example-2.com", preflight("http://example-2.com"))
}
// Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases walks the combinations
// of AllowOrigins (set / empty) and AllowOriginsFunc (nil / always true /
// always false). Per the table, a request origin is echoed when it is in
// AllowOrigins or the predicate returns true. With AllowOrigins empty and no
// predicate, the wildcard "*" is returned. Otherwise the header is empty.
func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		Name           string
		Config         Config
		RequestOrigin  string
		ResponseOrigin string
	}{
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed",
			Config: Config{
				AllowOrigins:     "http://aaa.com",
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
			Config: Config{
				AllowOrigins:     "http://aaa.com",
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
			Config: Config{
				AllowOrigins: "http://aaa.com",
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
			Config: Config{
				AllowOrigins: "http://aaa.com",
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "http://bbb.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
			Config: Config{
				AllowOrigins: "http://aaa.com",
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
			Config: Config{
				AllowOrigins: "http://aaa.com",
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed",
			Config: Config{
				AllowOrigins:     "",
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "*",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
			Config: Config{
				AllowOrigins: "",
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
			Config: Config{
				AllowOrigins: "",
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "",
		},
	}

	for _, tc := range testCases {
		tc := tc // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tc.Name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use("/", New(tc.Config))
			handler := app.Handler()

			ctx := &fasthttp.RequestCtx{}
			ctx.Request.SetRequestURI("/")
			ctx.Request.Header.SetMethod(fiber.MethodOptions)
			ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
			handler(ctx)

			require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
		})
	}
}
// Test_CORS_AllowCredentials is the regression test for issue #2422. It checks
// which Access-Control-Allow-Origin and Access-Control-Allow-Credentials values
// a preflight response carries for different origin/credentials setups,
// including that a literal "*" request origin never gets credentials
// (advisory fiber-ghsa-fmg4-x8pw-hjhg).
func Test_CORS_AllowCredentials(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		Name                string
		Config              Config
		RequestOrigin       string
		ResponseOrigin      string
		ResponseCredentials string
	}{
		{
			Name: "AllowOriginsFuncDefined",
			Config: Config{
				AllowCredentials: true,
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin: "http://aaa.com",
			// AllowOriginsFunc is defined, so the actual request origin is echoed.
			ResponseOrigin:      "http://aaa.com",
			ResponseCredentials: "true",
		},
		{
			Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials",
			Config: Config{
				AllowCredentials: true,
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "*",
			ResponseOrigin: "*",
			// The middleware must not set credentials to true for a wildcard origin.
			ResponseCredentials: "",
		},
		{
			Name: "AllowOriginsFuncNotDefined",
			Config: Config{
				// Setting this to true would make the middleware panic, since the
				// default AllowOrigins is "*".
				AllowCredentials: false,
			},
			RequestOrigin: "http://aaa.com",
			// Neither AllowOrigins nor AllowOriginsFunc is set, so the default "*"
			// is returned. A browser sending credentials would reject this with:
			// The value of the 'Access-Control-Allow-Origin' header in the response
			// must not be the wildcard '*' when the request's credentials mode is
			// 'include'.
			ResponseOrigin:      "*",
			ResponseCredentials: "",
		},
		{
			Name: "AllowOriginsDefined",
			Config: Config{
				AllowCredentials: true,
				AllowOrigins:     "http://aaa.com",
			},
			RequestOrigin:       "http://aaa.com",
			ResponseOrigin:      "http://aaa.com",
			ResponseCredentials: "true",
		},
		{
			Name: "AllowOriginsDefined/UnallowedOrigin",
			Config: Config{
				AllowCredentials: true,
				AllowOrigins:     "http://aaa.com",
			},
			RequestOrigin:       "http://bbb.com",
			ResponseOrigin:      "",
			ResponseCredentials: "",
		},
	}

	for _, tc := range testCases {
		tc := tc // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tc.Name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use("/", New(tc.Config))
			handler := app.Handler()

			ctx := &fasthttp.RequestCtx{}
			ctx.Request.SetRequestURI("/")
			ctx.Request.Header.SetMethod(fiber.MethodOptions)
			ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
			handler(ctx)

			require.Equal(t, tc.ResponseCredentials, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
			require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
		})
	}
}