chore(middleware/cors): Merge changes from v2 (#2922)

* fix(middleware/cors): Handling and wildcard subdomain matching (#2915)

* fix: allow origins check

Refactor CORS origin validation and normalization to trim leading or trailing whitespace in the cfg.AllowOrigins string [list]. URLs with whitespace inside the URL are invalid, so the normalizeOrigin will return false because url.Parse will fail, and the middleware will panic.

fixes #2882

* test: AllowOrigins with whitespace

* test(middleware/cors): add benchmarks

* chore: fix linter errors

* test(middleware/cors): use h() instead of app.Test()

* test(middleware/cors): add miltiple origins in Test_CORS_AllowOriginScheme

* chore: refactor validate and normalize

* test(cors/middleware): add more benchmarks

* fix(middleware/cors): handling and wildcard subdomain matching

docs(middleware/cors): add How it works and Security Considerations

* chore: grammar

* Apply suggestions from code review

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* chore: fix misspelling

* test(middleware/cors): combine Invalid_Origins tests

* refactor(middleware/cors): headers handling

* docs(middleware/cors): Update AllowOrigins description

* chore: merge

* perf(middleware/cors): optimize handler

* perf(middleware/cors): optimize handler

* chore(middleware/cors): ipdate origin handling logic

* chore(middleware/cors): fix header capitalization

* docs(middleware/cors): improve sercuity notes

* docs(middleware/cors): Improve security notes

* docs(middleware/cors): improve CORS overview

* docs(middleware/cors): fix ordering of how it works

* docs(middleware/cors): add additional info to How to works

* docs(middleware/cors): rm space

* docs(middleware/cors): add validation for AllowOrigins origins to overview

* docs(middleware/cors): update ExposeHeaders and MaxAge descriptions

* docs(middleware/cors): Add dynamic origin validation example

* docs(middleware/cors): Improve security notes and fix header capitalization

* docs(middleware/cors): configuration examples

* docs(middleware/cors): `"*"`

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* test(middleware/cors): improve test coverage for request types

* chore(middleware/cors): fix v2 merge issues

* test(middleware/cors): Add subdomain matching tests

* fix(middleware/cors): Update Next function signature

* test(middleware/cors): Add benchmark for CORS subdomain matching

* test(middleware/cors): cover additiona test cases

* refactor(middleware/cors): origin validation and normalization

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
pull/2929/head
Jason McNeil 2024-03-19 04:32:19 -03:00 committed by GitHub
parent 43dc60fb27
commit 7fa8b2d4ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 543 additions and 174 deletions

View File

@ -4,13 +4,15 @@ id: cors
# CORS
CORS middleware for [Fiber](https://github.com/gofiber/fiber) that can be used to enable [Cross-Origin Resource Sharing](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) with various options.
CORS (Cross-Origin Resource Sharing) is a middleware for [Fiber](https://github.com/gofiber/fiber) that allows servers to specify who can access its resources and how. It's not a security feature, but a way to relax the security model of web browsers for cross-origin requests. You can learn more about CORS on [Mozilla Developer Network](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS).
The middleware conforms to the `access-control-allow-origin` specification by parsing `AllowOrigins`. First, the middleware checks if there is a matching allowed origin for the requesting 'origin' header. If there is a match, it returns exactly one matching domain from the list of allowed origins.
This middleware works by adding CORS headers to responses from your Fiber application. These headers specify which origins, methods, and headers are allowed for cross-origin requests. It also handles preflight requests, which are a CORS mechanism to check if the actual request is safe to send.
For more control, `AllowOriginsFunc` can be used to programatically determine if an origin is allowed. If no match was found in `AllowOrigins` and if `AllowOriginsFunc` returns true then the 'access-control-allow-origin' response header is set to the 'origin' request header.
The middleware uses the `AllowOrigins` option to control which origins can make cross-origin requests. It supports single origin, multiple origins, subdomain matching, and wildcard origin. It also allows programmatic origin validation with the `AllowOriginsFunc` option.
When defining your Origins make sure they are properly formatted. The middleware validates and normalizes the provided origins, ensuring they're in the correct format by checking for valid schemes (http or https), and removing any trailing slashes.
To ensure that the provided `AllowOrigins` origins are correctly formatted, this middleware validates and normalizes them. It checks for valid schemes, i.e., HTTP or HTTPS, and it will automatically remove trailing slashes. If the provided origin is invalid, the middleware will panic.
When configuring CORS, it's important to avoid [common pitfalls](#common-pitfalls) like using a wildcard origin with credentials, being overly permissive with origins, and inadequate validation with `AllowOriginsFunc`. Misconfiguration can expose your application to various security risks.
## Signatures
@ -31,6 +33,16 @@ import (
After you initiate your Fiber app, you can use the following possibilities:
### Basic usage
To use the default configuration, simply use `cors.New()`. This will allow wildcard origins '*', all methods, no credentials, and no headers or exposed headers.
```go
app.Use(cors.New())
```
### Custom configuration (specific origins, headers, etc.)
```go
// Initialize default config
app.Use(cors.New())
@ -38,27 +50,50 @@ app.Use(cors.New())
// Or extend your config for customization
app.Use(cors.New(cors.Config{
AllowOrigins: "https://gofiber.io, https://gofiber.net",
AllowHeaders: "Origin, Content-Type, Accept",
AllowHeaders: "Origin, Content-Type, Accept",
}))
```
Using the `AllowOriginsFunc` function. In this example any origin will be allowed via CORS.
### Dynamic origin validation
For example, if a browser running on `http://localhost:3000` sends a request, this will be accepted and the `access-control-allow-origin` response header will be set to `http://localhost:3000`.
You can use `AllowOriginsFunc` to programmatically determine whether to allow a request based on its origin. This is useful when you need to validate origins against a database or other dynamic sources. The function should return `true` if the origin is allowed, and `false` otherwise.
**Note: Using this feature is discouraged in production and it's best practice to explicitly set CORS origins via `AllowOrigins`.**
Be sure to review the [security considerations](#security-considerations) when using `AllowOriginsFunc`.
:::caution
Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats.
If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`.
:::
```go
app.Use(cors.New())
// dbCheckOrigin checks if the origin is in the list of allowed origins in the database.
func dbCheckOrigin(db *sql.DB, origin string) bool {
// Placeholder query - adjust according to your database schema and query needs
query := "SELECT COUNT(*) FROM allowed_origins WHERE origin = $1"
var count int
err := db.QueryRow(query, origin).Scan(&count)
if err != nil {
// Handle error (e.g., log it); for simplicity, we return false here
return false
}
return count > 0
}
// ...
app.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return os.Getenv("ENVIRONMENT") == "development"
},
AllowOriginsFunc: func(origin string) bool {
return dbCheckOrigin(db, origin)
},
}))
```
**Note: The following configuration is considered insecure and will result in a panic.**
### Prohibited usage
The following example is prohibited because it can expose your application to security risks. It sets `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true`.
```go
app.Use(cors.New(cors.Config{
@ -67,18 +102,24 @@ app.Use(cors.New(cors.Config{
}))
```
This will result in the following panic:
```
panic: [CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to `"*"`.
```
## Config
| Property | Type | Description | Default |
|:-----------------|:---------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. This allows for dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins will be not have the 'access-control-allow-credentials' header set to 'true'. | `nil` |
| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` |
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| AllowOriginsFunc | `func(origin string) bool` | `AllowOriginsFunc` is a function that dynamically determines whether to allow a request based on its origin. If this function returns `true`, the 'Access-Control-Allow-Origin' response header will be set to the request's 'origin' header. This function is only used if the request's origin doesn't match any origin in `AllowOrigins`. | `nil` |
| AllowOrigins | `string` | AllowOrigins defines a comma separated list of origins that may access the resource. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `"*"` |
| AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` |
| AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` |
| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard ("*") to prevent security vulnerabilities. | `false` |
| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` |
| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` |
| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard (`"*"`) to prevent security vulnerabilities. | `false` |
| ExposeHeaders | `string` | ExposeHeaders defines whitelist headers that clients are allowed to access. | `""` |
| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, the Access-Control-Max-Age header will not be added and the browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header to 0. | `0` |
## Default Config
@ -101,3 +142,73 @@ var ConfigDefault = Config{
MaxAge: 0,
}
```
## Subdomain Matching
The `AllowOrigins` configuration supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.
### Example
If you want to allow CORS requests from any subdomain of `example.com`, including nested subdomains, you can configure the `AllowOrigins` like so:
```go
app.Use(cors.New(cors.Config{
AllowOrigins: "https://*.example.com",
}))
```
# How It Works
The CORS middleware works by adding the necessary CORS headers to responses from your Fiber application. These headers tell browsers what origins, methods, and headers are allowed for cross-origin requests.
When a request comes in, the middleware first checks if it's a preflight request, which is a CORS mechanism to determine whether the actual request is safe to send. Preflight requests are HTTP OPTIONS requests with specific CORS headers. If it's a preflight request, the middleware responds with the appropriate CORS headers and ends the request.
If it's not a preflight request, the middleware adds the CORS headers to the response and passes the request to the next handler. The actual CORS headers added depend on the configuration of the middleware.
The `AllowOrigins` option controls which origins can make cross-origin requests. The middleware handles different `AllowOrigins` configurations as follows:
- **Single origin:** If `AllowOrigins` is set to a single origin like `"http://www.example.com"`, and that origin matches the origin of the incoming request, the middleware adds the header `Access-Control-Allow-Origin: http://www.example.com` to the response.
- **Multiple origins:** If `AllowOrigins` is set to multiple origins like `"https://example.com, https://www.example.com"`, the middleware picks the origin that matches the origin of the incoming request.
- **Subdomain matching:** If `AllowOrigins` includes `"https://*.example.com"`, a subdomain like `https://sub.example.com` will be matched and `"https://sub.example.com"` will be the header. This will also match `https://sub.sub.example.com` and so on, but not `https://example.com`.
- **Wildcard origin:** If `AllowOrigins` is set to `"*"`, the middleware uses that and adds the header `Access-Control-Allow-Origin: *` to the response.
In all cases above, except the **Wildcard origin**, the middleware will either add the `Access-Control-Allow-Origin` header to the response matching the origin of the incoming request, or it will not add the header at all if the origin is not allowed.
- **Programmatic origin validation:**: The middleware also handles the `AllowOriginsFunc` option, which allows you to programmatically determine if an origin is allowed. If `AllowOriginsFunc` returns `true` for an origin, the middleware sets the `Access-Control-Allow-Origin` header to that origin.
The `AllowMethods` option controls which HTTP methods are allowed. For example, if `AllowMethods` is set to `"GET, POST"`, the middleware adds the header `Access-Control-Allow-Methods: GET, POST` to the response.
The `AllowHeaders` option specifies which headers are allowed in the actual request. The middleware sets the Access-Control-Allow-Headers response header to the value of `AllowHeaders`. This informs the client which headers it can use in the actual request.
The `AllowCredentials` option indicates whether the response to the request can be exposed when the credentials flag is true. If `AllowCredentials` is set to `true`, the middleware adds the header `Access-Control-Allow-Credentials: true` to the response. To prevent security vulnerabilities, `AllowCredentials` cannot be set to `true` if `AllowOrigins` is set to a wildcard (`*`).
The `ExposeHeaders` option defines a whitelist of headers that clients are allowed to access. If `ExposeHeaders` is set to `"X-Custom-Header"`, the middleware adds the header `Access-Control-Expose-Headers: X-Custom-Header` to the response.
The `MaxAge` option indicates how long the results of a preflight request can be cached. If `MaxAge` is set to `3600`, the middleware adds the header `Access-Control-Max-Age: 3600` to the response.
The `Vary` header is used in this middleware to inform the client that the server's response to a request. For or both preflight and actual requests, the Vary header is set to `Access-Control-Request-Method` and `Access-Control-Request-Headers`. For preflight requests, the Vary header is also set to `Origin`. The `Vary` header is important for caching. It helps caches (like a web browser's cache or a CDN) determine when a cached response can be used in response to a future request, and when the server needs to be queried for a new response.
## Security Considerations
When configuring CORS, misconfiguration can potentially expose your application to various security risks. Here are some secure configurations and common pitfalls to avoid:
### Secure Configurations
- **Specify Allowed Origins**: Instead of using a wildcard (`"*"`), specify the exact domains allowed to make requests. For example, `AllowOrigins: "https://www.example.com, https://api.example.com"` ensures only these domains can make cross-origin requests to your application.
- **Use Credentials Carefully**: If your application needs to support credentials in cross-origin requests, ensure `AllowCredentials` is set to `true` and specify exact origins in `AllowOrigins`. Do not use a wildcard origin in this case.
- **Limit Exposed Headers**: Only whitelist headers that are necessary for the client-side application by setting `ExposeHeaders` appropriately. This minimizes the risk of exposing sensitive information.
### Common Pitfalls
- **Wildcard Origin with Credentials**: Setting `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true` is a common misconfiguration. This combination is prohibited because it can expose your application to security risks.
- **Overly Permissive Origins**: Specifying too many origins or using overly broad patterns (e.g., `https://*.example.com`) can inadvertently allow malicious sites to interact with your application. Be as specific as possible with allowed origins.
- **Inadequate `AllowOriginsFunc` Validation**: When using `AllowOriginsFunc` for dynamic origin validation, ensure the function includes robust checks to prevent unauthorized origins from being accepted. Overly permissive validation can lead to security vulnerabilities. Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`.
Remember, the key to secure CORS configuration is specificity and caution. By carefully selecting which origins, methods, and headers are allowed, you can help protect your application from cross-origin attacks.

View File

@ -15,10 +15,10 @@ type Config struct {
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
// response header to the 'origin' request header when returned true. This allows for
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
// will be not have the 'access-control-allow-credentials' header set to 'true'.
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
@ -110,31 +110,38 @@ func New(config ...Config) fiber.Handler {
// Validate CORS credentials configuration
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
log.Panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") //nolint:revive // we want to exit the program
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
var allowOrigins []string
allowOrigins := []string{}
allowSOrigins := []subdomain{}
allowAllOrigins := false
// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
allowOrigins = make([]string, len(origins))
for i, origin := range origins {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
log.Panicf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) //nolint:revive // we want to exit the program
for _, origin := range origins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
allowSOrigins = append(allowSOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
allowOrigins = append(allowOrigins, normalizedOrigin)
}
allowOrigins[i] = normalizedOrigin
}
} else {
// If AllowOrigins is set to a wildcard or not set,
// set allowOrigins to a slice with a single element
allowOrigins = []string{cfg.AllowOrigins}
} else if cfg.AllowOrigins == "*" {
allowAllOrigins = true
}
// Strip white spaces
@ -153,18 +160,37 @@ func New(config ...Config) fiber.Handler {
}
// Get originHeader header
originHeader := c.Get(fiber.HeaderOrigin)
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}
// Set default allowOrigin to empty string
allowOrigin := ""
// Check allowed origins
for _, origin := range allowOrigins {
if origin == "*" {
allowOrigin = "*"
break
if allowAllOrigins {
allowOrigin = "*"
} else {
// Check if the origin is in the list of allowed origins
for _, origin := range allowOrigins {
if origin == originHeader {
allowOrigin = originHeader
break
}
}
if validateDomain(originHeader, origin) {
allowOrigin = originHeader
break
// Check if the origin is in the list of allowed subdomains
if allowOrigin == "" {
for _, sOrigin := range allowSOrigins {
if sOrigin.match(originHeader) {
allowOrigin = originHeader
break
}
}
}
}
@ -176,57 +202,65 @@ func New(config ...Config) fiber.Handler {
}
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}
// Preflight request
c.Vary(fiber.HeaderOrigin)
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
log.Warn("[CORS] 'AllowCredentials' is true. Ensure 'AllowOrigins' is not set to '*' in the configuration.")
}
} else {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set MaxAge is set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}
// Function to set CORS headers
func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Allow-Methods if not empty
if allowMethods != "" {
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
}
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set MaxAge if set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}
// Set Expose-Headers if not empty
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
}

View File

@ -34,6 +34,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)
@ -48,6 +49,8 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
@ -57,6 +60,8 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Equal(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)))
@ -84,6 +89,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
// Perform request
handler(ctx)
@ -97,6 +103,8 @@ func Test_CORS_Wildcard(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
@ -124,6 +132,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
// Perform request
handler(ctx)
@ -136,6 +145,8 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)
@ -170,27 +181,41 @@ func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) {
}
// go test -run -v Test_CORS_Invalid_Origin_Panic
func Test_CORS_Invalid_Origin_Panic(t *testing.T) {
func Test_CORS_Invalid_Origins_Panic(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
invalidOrigins := []string{
"localhost",
"http://foo.[a-z]*.example.com",
"http://*",
"https://*",
"http://*.com*",
"invalid url",
"http://origin.com,invalid url",
// add more invalid origins as needed
}
for _, origin := range invalidOrigins {
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
app.Use(New(Config{
AllowOrigins: origin,
AllowCredentials: true,
}))
}()
app.Use(New(Config{
AllowOrigins: "localhost",
AllowCredentials: true,
}))
}()
if !didPanic {
t.Errorf("Expected a panic when Origin is missing scheme")
if !didPanic {
t.Errorf("Expected a panic for invalid origin: %s", origin)
}
}
}
@ -209,6 +234,7 @@ func Test_CORS_Subdomain(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
@ -220,9 +246,23 @@ func Test_CORS_Subdomain(t *testing.T) {
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with domain only (disallowed)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
handler(ctx)
@ -241,6 +281,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: "HTTP://EXAMPLE.COM",
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: "https://example.com",
reqOrigin: "https://example.com",
@ -271,6 +316,11 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
reqOrigin: "http://aaa.example.com:8080",
shouldAllowOrigin: true,
},
{
pattern: "http://*.example.com",
reqOrigin: "http://1.2.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://example.com",
reqOrigin: "http://gofiber.com",
@ -292,7 +342,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
shouldAllowOrigin: false,
},
{
pattern: "https://*--aaa.bbb.com",
pattern: "https://--aaa.bbb.com",
reqOrigin: "https://prod-preview--aaa.bbb.com",
shouldAllowOrigin: false,
},
@ -302,8 +352,13 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
shouldAllowOrigin: true,
},
{
pattern: "http://foo.[a-z]*.example.com",
reqOrigin: "http://ccc.bbb.example.com",
pattern: "http://domain-1.com, http://example.com",
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://domain-1.com, http://example.com",
reqOrigin: "http://domain-2.com",
shouldAllowOrigin: false,
},
{
@ -332,6 +387,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
handler(ctx)
@ -344,6 +400,35 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
}
}
func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOrigins: "http://example-1.com, https://example-1.com",
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
var headerExists bool
ctx.Response.Header.VisitAll(func(key, _ []byte) {
if string(key) == fiber.HeaderAccessControlAllowOrigin {
headerExists = true
}
})
require.False(t, headerExists, "Access-Control-Allow-Origin header should not be set")
}
// go test -run Test_CORS_Next
func Test_CORS_Next(t *testing.T) {
t.Parallel()
@ -359,6 +444,103 @@ func Test_CORS_Next(t *testing.T) {
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_CORS_Headers_BasedOnRequestType
func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{}))
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
methods := []string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
fiber.MethodHead,
}
// Get handler pointer
handler := app.Handler()
t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
require.Equal(t, 204, ctx.Response.StatusCode(), "Status code should be 204")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
require.Equal(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)")
}
})
t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)")
}
})
}
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
@ -377,6 +559,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
@ -391,6 +574,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
handler(ctx)
@ -403,6 +587,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
@ -442,6 +627,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
@ -589,6 +775,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
handler(ctx)
@ -679,6 +866,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
handler(ctx)

View File

@ -12,37 +12,6 @@ func matchScheme(domain, pattern string) bool {
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// validateDomain checks if the domain matches the pattern
func validateDomain(domain, pattern string) bool {
// Directly compare the domain and pattern for an exact match.
if domain == pattern {
return true
}
// Normalize domain and pattern to exclude schemes and ports for matching purposes
normalizedDomain := normalizeDomain(domain)
normalizedPattern := normalizeDomain(pattern)
// Handling the case where pattern is a wildcard subdomain pattern.
if strings.HasPrefix(normalizedPattern, "*.") {
// Trim leading "*." from pattern for comparison.
trimmedPattern := normalizedPattern[2:]
// Check if the domain ends with the trimmed pattern.
if strings.HasSuffix(normalizedDomain, trimmedPattern) {
// Ensure that the domain is not exactly the base domain.
if normalizedDomain != trimmedPattern {
// Special handling to prevent "example.com" matching "*.example.com".
if strings.TrimSuffix(normalizedDomain, trimmedPattern) != "" {
return true
}
}
}
}
return false
}
// normalizeDomain removes the scheme and port from the input domain
func normalizeDomain(input string) string {
// Remove scheme
@ -73,6 +42,13 @@ func normalizeOrigin(origin string) (bool, string) {
return false, ""
}
// Don't allow a wildcard with a protocol
// wildcards cannot be used within any other value. For example, the following header is not valid:
// Access-Control-Allow-Origin: https://*
if strings.Contains(parsedOrigin.Host, "*") {
return false, ""
}
// Validate there is a host present. The presence of a path, query, or fragment components
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
@ -83,3 +59,13 @@ func normalizeOrigin(origin string) (bool, string) {
// The path or trailing slash is not included in the normalized origin.
return true, strings.ToLower(parsedOrigin.Scheme + "://" + parsedOrigin.Host)
}
type subdomain struct {
// The wildcard pattern
prefix string
suffix string
}
func (s subdomain) match(o string) bool {
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
}

View File

@ -2,6 +2,8 @@ package cors
import (
"testing"
"github.com/stretchr/testify/assert"
)
// go test -run -v Test_normalizeOrigin
@ -16,6 +18,9 @@ func Test_normalizeOrigin(t *testing.T) {
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
{"http://", false, ""}, // Invalid origin should not be accepted.
{"file:///etc/passwd", false, ""}, // File scheme should not be accepted.
{"https://*example.com", false, ""}, // Wildcard domain should not be accepted.
{"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted.
{"http://example.com/path", false, ""}, // Path should not be accepted.
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
@ -75,44 +80,6 @@ func Test_matchScheme(t *testing.T) {
}
}
// go test -run -v Test_validateOrigin
func Test_validateOrigin(t *testing.T) {
testCases := []struct {
domain string
pattern string
expected bool
}{
{"http://example.com", "http://example.com", true}, // Exact match should work.
{"https://example.com", "http://example.com", false}, // Scheme mismatch should matter in CORS context.
{"http://example.com", "https://example.com", false}, // Scheme mismatch should matter in CORS context.
{"http://example.com", "http://example.org", false}, // Different domains should not match.
{"http://example.com", "http://example.com:8080", false}, // Port mismatch should matter.
{"http://example.com:8080", "http://example.com", false}, // Port mismatch should matter.
{"http://example.com:8080", "http://example.com:8081", false}, // Different ports should not match.
{"example.com", "example.com", true}, // Simplified form, assuming scheme and port are not considered here, but in practice, they are part of the origin.
{"sub.example.com", "example.com", false}, // Subdomain should not match the base domain directly.
{"sub.example.com", "*.example.com", true}, // Correct assumption for wildcard subdomain matching.
{"example.com", "*.example.com", false}, // Base domain should not match its wildcard subdomain pattern.
{"sub.example.com", "*.com", true}, // Technically correct for pattern matching, but broad wildcard use like this is not recommended for CORS.
{"sub.sub.example.com", "*.example.com", true}, // Nested subdomain should match the wildcard pattern.
{"example.com", "*.org", false}, // Different TLDs should not match.
{"example.com", "example.org", false}, // Different domains should not match.
{"example.com:8080", "*.example.com", false}, // Different ports mean different origins.
{"example.com", "sub.example.net", false}, // Different domains should not match.
{"http://localhost", "http://localhost", true}, // Localhost should match.
{"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match.
{"http://[::1]", "http://[::1]", true}, // IPv6 address should match.
}
for _, tc := range testCases {
result := validateDomain(tc.domain, tc.pattern)
if result != tc.expected {
t.Errorf("Expected validateOrigin('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
}
}
}
// go test -run -v Test_normalizeDomain
func Test_normalizeDomain(t *testing.T) {
testCases := []struct {
@ -143,3 +110,86 @@ func Test_normalizeDomain(t *testing.T) {
}
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
s := subdomain{
prefix: "www",
suffix: ".example.com",
}
o := "www.example.com"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s.match(o)
}
}
func Test_CORS_SubdomainMatch(t *testing.T) {
tests := []struct {
name string
sub subdomain
origin string
expected bool
}{
{
name: "match with different scheme",
sub: subdomain{prefix: "http://api.", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: false,
},
{
name: "match with different scheme",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "http://api.service.example.com",
expected: false,
},
{
name: "match with valid subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: true,
},
{
name: "match with valid nested subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://1.2.api.service.example.com",
expected: true,
},
{
name: "no match with invalid prefix",
sub: subdomain{prefix: "https://abc.", suffix: ".example.com"},
origin: "https://service.example.com",
expected: false,
},
{
name: "no match with invalid suffix",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.example.org",
expected: false,
},
{
name: "no match with empty origin",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "",
expected: false,
},
{
name: "partial match not considered a match",
sub: subdomain{prefix: "https://service.", suffix: ".example.com"},
origin: "https://api.example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sub.match(tt.origin)
assert.Equal(t, tt.expected, got, "subdomain.match()")
})
}
}