Merge pull request from GHSA-fmg4-x8pw-hjhg

* Enforce Wildcard Origins with AllowCredentials check

* Expand unit-tests, fix issues with subdomains logic, update docs

* Update cors.md

* Added test using localhost, ipv4, and ipv6 address

* improve documentation markdown

---------

Co-authored-by: René Werner <rene@gofiber.io>
pull/2883/head
Juan Calderon-Perez 2024-02-21 08:47:33 -05:00 committed by GitHub
parent 5e30112d08
commit f0cd3b44b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 404 additions and 90 deletions

View File

@ -10,6 +10,8 @@ The middleware conforms to the `access-control-allow-origin` specification by pa
For more control, `AllowOriginsFunc` can be used to programatically determine if an origin is allowed. If no match was found in `AllowOrigins` and if `AllowOriginsFunc` returns true then the 'access-control-allow-origin' response header is set to the 'origin' request header.
When defining your Origins make sure they are properly formatted. The middleware validates and normalizes the provided origins, ensuring they're in the correct format by checking for valid schemes (http or https), and removing any trailing slashes.
## Signatures
```go
@ -56,18 +58,27 @@ app.Use(cors.New(cors.Config{
}))
```
**Note: The following configuration is considered insecure and will result in a panic.**
```go
app.Use(cors.New(cors.Config{
AllowOrigins: "*",
AllowCredentials: true,
}))
```
## Config
| Property | Type | Description | Default |
|:-----------------|:---------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------|
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. | `nil` |
| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` |
| AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` |
| AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` |
| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. | `false` |
| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` |
| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` |
| Property | Type | Description | Default |
|:-----------------|:---------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------|
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. This allows for dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins will be not have the 'access-control-allow-credentials' header set to 'true'. | `nil` |
| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` |
| AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` |
| AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` |
| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard ("*") to prevent security vulnerabilities. | `false` |
| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` |
| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` |
## Default Config

View File

@ -16,12 +16,14 @@ type Config struct {
Next func(c *fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
// response header to the 'origin' request header when returned true. This allows for
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
// will be not have the 'access-control-allow-credentials' header set to 'true'.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
// AllowOrigin defines a list of origins that may access the resource.
// AllowOrigin defines a comma separated list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
@ -41,7 +43,8 @@ type Config struct {
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// actual request can be made using credentials. Note: If true, AllowOrigins
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
//
// Optional. Default value false.
AllowCredentials bool
@ -105,6 +108,26 @@ func New(config ...Config) fiber.Handler {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}
// Validate CORS credentials configuration
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}
// Validate and normalize static AllowOrigins if not using AllowOriginsFunc
if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
validatedOrigins := []string{}
for _, origin := range strings.Split(cfg.AllowOrigins, ",") {
isValid, normalizedOrigin := normalizeOrigin(origin)
if isValid {
validatedOrigins = append(validatedOrigins, normalizedOrigin)
} else {
log.Warnf("[CORS] Invalid origin format in configuration: %s", origin)
panic("[CORS] Invalid origin provided in configuration")
}
}
cfg.AllowOrigins = strings.Join(validatedOrigins, ",")
}
// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")
@ -123,22 +146,18 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
// Get origin header
origin := c.Get(fiber.HeaderOrigin)
// Get originHeader header
originHeader := c.Get(fiber.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range allowOrigins {
if o == "*" {
for _, origin := range allowOrigins {
if origin == "*" {
allowOrigin = "*"
break
}
if o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
if validateDomain(originHeader, origin) {
allowOrigin = originHeader
break
}
}
@ -147,8 +166,8 @@ func New(config ...Config) fiber.Handler {
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
if cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
}
@ -173,9 +192,17 @@ func New(config ...Config) fiber.Handler {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
// Set Allow-Credentials if set to true
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
}
} else {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Allow-Headers if not empty

View File

@ -35,7 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)
utils.AssertEqual(t, "0", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
@ -72,7 +72,46 @@ func Test_CORS_Wildcard(t *testing.T) {
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is *
app.Use(New(Config{
AllowOrigins: "*",
AllowOrigins: "*",
MaxAge: 3600,
ExposeHeaders: "X-Request-ID",
AllowHeaders: "Authentication",
}))
// Get handler pointer
handler := app.Handler()
// Make request
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
// Perform request
handler(ctx)
// Check result
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// go test -run -v Test_CORS_Origin_AllowCredentials
func Test_CORS_Origin_AllowCredentials(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is *
app.Use(New(Config{
AllowOrigins: "http://localhost",
AllowCredentials: true,
MaxAge: 3600,
ExposeHeaders: "X-Request-ID",
@ -84,14 +123,14 @@ func Test_CORS_Wildcard(t *testing.T) {
// Make request
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
// Perform request
handler(ctx)
// Check result
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
utils.AssertEqual(t, "http://localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
@ -105,6 +144,57 @@ func Test_CORS_Wildcard(t *testing.T) {
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// go test -run -v Test_CORS_Wildcard_AllowCredentials_Panic
// Test for fiber-ghsa-fmg4-x8pw-hjhg
func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
app.Use(New(Config{
AllowOrigins: "*",
AllowCredentials: true,
}))
}()
if !didPanic {
t.Errorf("Expected a panic when AllowOrigins is '*' and AllowCredentials is true")
}
}
// go test -run -v Test_CORS_Invalid_Origin_Panic
func Test_CORS_Invalid_Origin_Panic(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
app.Use(New(Config{
AllowOrigins: "localhost",
AllowCredentials: true,
}))
}()
if !didPanic {
t.Errorf("Expected a panic when Origin is missing scheme")
}
}
// go test -run -v Test_CORS_Subdomain
func Test_CORS_Subdomain(t *testing.T) {
t.Parallel()
@ -193,12 +283,9 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
shouldAllowOrigin: false,
},
{
pattern: "http://*.example.com",
reqOrigin: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
shouldAllowOrigin: false,
pattern: "http://*.example.com",
reqOrigin: "http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://example.com",
@ -471,12 +558,13 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
}
// The fix for issue #2422
func Test_CORS_AllowCredetials(t *testing.T) {
func Test_CORS_AllowCredentials(t *testing.T) {
testCases := []struct {
Name string
Config Config
RequestOrigin string
ResponseOrigin string
Name string
Config Config
RequestOrigin string
ResponseOrigin string
ResponseCredentials string
}{
{
Name: "AllowOriginsFuncDefined",
@ -488,19 +576,35 @@ func Test_CORS_AllowCredetials(t *testing.T) {
},
RequestOrigin: "http://aaa.com",
// The AllowOriginsFunc config was defined, should use the real origin of the function
ResponseOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
ResponseCredentials: "true",
},
{
Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials",
Config: Config{
AllowCredentials: true,
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "*",
ResponseOrigin: "*",
// Middleware will validate that wildcard wont set credentials to true
ResponseCredentials: "",
},
{
Name: "AllowOriginsFuncNotDefined",
Config: Config{
AllowCredentials: true,
// Setting this to true will cause the middleware to panic since default AllowOrigins is "*"
AllowCredentials: false,
},
RequestOrigin: "http://aaa.com",
// None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*"
// which will cause the CORS error in the client:
// The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*'
// when the request's credentials mode is 'include'.
ResponseOrigin: "*",
ResponseOrigin: "*",
ResponseCredentials: "",
},
{
Name: "AllowOriginsDefined",
@ -508,8 +612,9 @@ func Test_CORS_AllowCredetials(t *testing.T) {
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
ResponseCredentials: "true",
},
{
Name: "AllowOriginsDefined/UnallowedOrigin",
@ -517,8 +622,9 @@ func Test_CORS_AllowCredetials(t *testing.T) {
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
ResponseCredentials: "",
},
}
@ -536,9 +642,7 @@ func Test_CORS_AllowCredetials(t *testing.T) {
handler(ctx)
if tc.Config.AllowCredentials {
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
}
utils.AssertEqual(t, tc.ResponseCredentials, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
})
}

View File

@ -1,56 +1,83 @@
package cors
import (
"net/url"
"strings"
)
// matchScheme compares the scheme of the domain and pattern
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
const maxDomainLen = 253
if len(domAuth) > maxDomainLen {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
const divHalf = 2
for i := len(domComp)/divHalf - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/divHalf - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
// validateDomain checks if the domain matches the pattern
func validateDomain(domain, pattern string) bool {
// Directly compare the domain and pattern for an exact match.
if domain == pattern {
return true
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
// Normalize domain and pattern to exclude schemes and ports for matching purposes
normalizedDomain := normalizeDomain(domain)
normalizedPattern := normalizeDomain(pattern)
// Handling the case where pattern is a wildcard subdomain pattern.
if strings.HasPrefix(normalizedPattern, "*.") {
// Trim leading "*." from pattern for comparison.
trimmedPattern := normalizedPattern[2:]
// Check if the domain ends with the trimmed pattern.
if strings.HasSuffix(normalizedDomain, trimmedPattern) {
// Ensure that the domain is not exactly the base domain.
if normalizedDomain != trimmedPattern {
// Special handling to prevent "example.com" matching "*.example.com".
if strings.TrimSuffix(normalizedDomain, trimmedPattern) != "" {
return true
}
}
}
}
return false
}
// normalizeDomain removes the scheme and port from the input domain
func normalizeDomain(input string) string {
// Remove scheme
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
// Find and remove port, if present
if portIndex := strings.Index(input, ":"); portIndex != -1 {
input = input[:portIndex]
}
return input
}
// normalizeOrigin checks if the provided origin is in a correct format
// and normalizes it by removing any path or trailing slash.
// It returns a boolean indicating whether the origin is valid
// and the normalized origin.
func normalizeOrigin(origin string) (bool, string) {
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false, ""
}
// Validate the scheme is either http or https
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
return false, ""
}
// Validate there is a host present. The presence of a path, query, or fragment components
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
return false, ""
}
// Normalize the origin by constructing it from the scheme and host.
// The path or trailing slash is not included in the normalized origin.
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
}

View File

@ -0,0 +1,145 @@
package cors
import (
"testing"
)
// go test -run -v Test_normalizeOrigin
func Test_normalizeOrigin(t *testing.T) {
testCases := []struct {
origin string
expectedValid bool
expectedOrigin string
}{
{"http://example.com", true, "http://example.com"}, // Simple case should work.
{"http://example.com/", true, "http://example.com"}, // Trailing slash should be removed.
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
{"http://", false, ""}, // Invalid origin should not be accepted.
{"http://example.com/path", false, ""}, // Path should not be accepted.
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
{"http://localhost", true, "http://localhost"}, // Localhost should be accepted.
{"http://127.0.0.1", true, "http://127.0.0.1"}, // IPv4 address should be accepted.
{"http://[::1]", true, "http://[::1]"}, // IPv6 address should be accepted.
{"http://[::1]:8080", true, "http://[::1]:8080"}, // IPv6 address with port should be accepted.
{"http://[::1]:8080/", true, "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted.
{"http://[::1]:8080/path", false, ""}, // IPv6 address with port and path should not be accepted.
{"http://[::1]:8080?query=123", false, ""}, // IPv6 address with port and query should not be accepted.
{"http://[::1]:8080#fragment", false, ""}, // IPv6 address with port and fragment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment", false, ""}, // IPv6 address with port, path, query, and fragment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/", false, ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid/", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid/segment", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
}
for _, tc := range testCases {
valid, normalizedOrigin := normalizeOrigin(tc.origin)
if valid != tc.expectedValid {
t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
}
if normalizedOrigin != tc.expectedOrigin {
t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin)
}
}
}
// go test -run -v Test_matchScheme
func Test_matchScheme(t *testing.T) {
testCases := []struct {
domain string
pattern string
expected bool
}{
{"http://example.com", "http://example.com", true}, // Exact match should work.
{"https://example.com", "http://example.com", false}, // Scheme mismatch should matter.
{"http://example.com", "https://example.com", false}, // Scheme mismatch should matter.
{"http://example.com", "http://example.org", true}, // Different domains should not matter.
{"http://example.com", "http://example.com:8080", true}, // Port should not matter.
{"http://example.com:8080", "http://example.com", true}, // Port should not matter.
{"http://example.com:8080", "http://example.com:8081", true}, // Different ports should not matter.
{"http://localhost", "http://localhost", true}, // Localhost should match.
{"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match.
{"http://[::1]", "http://[::1]", true}, // IPv6 address should match.
}
for _, tc := range testCases {
result := matchScheme(tc.domain, tc.pattern)
if result != tc.expected {
t.Errorf("Expected matchScheme('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
}
}
}
// go test -run -v Test_validateOrigin
func Test_validateOrigin(t *testing.T) {
testCases := []struct {
domain string
pattern string
expected bool
}{
{"http://example.com", "http://example.com", true}, // Exact match should work.
{"https://example.com", "http://example.com", false}, // Scheme mismatch should matter in CORS context.
{"http://example.com", "https://example.com", false}, // Scheme mismatch should matter in CORS context.
{"http://example.com", "http://example.org", false}, // Different domains should not match.
{"http://example.com", "http://example.com:8080", false}, // Port mismatch should matter.
{"http://example.com:8080", "http://example.com", false}, // Port mismatch should matter.
{"http://example.com:8080", "http://example.com:8081", false}, // Different ports should not match.
{"example.com", "example.com", true}, // Simplified form, assuming scheme and port are not considered here, but in practice, they are part of the origin.
{"sub.example.com", "example.com", false}, // Subdomain should not match the base domain directly.
{"sub.example.com", "*.example.com", true}, // Correct assumption for wildcard subdomain matching.
{"example.com", "*.example.com", false}, // Base domain should not match its wildcard subdomain pattern.
{"sub.example.com", "*.com", true}, // Technically correct for pattern matching, but broad wildcard use like this is not recommended for CORS.
{"sub.sub.example.com", "*.example.com", true}, // Nested subdomain should match the wildcard pattern.
{"example.com", "*.org", false}, // Different TLDs should not match.
{"example.com", "example.org", false}, // Different domains should not match.
{"example.com:8080", "*.example.com", false}, // Different ports mean different origins.
{"example.com", "sub.example.net", false}, // Different domains should not match.
{"http://localhost", "http://localhost", true}, // Localhost should match.
{"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match.
{"http://[::1]", "http://[::1]", true}, // IPv6 address should match.
}
for _, tc := range testCases {
result := validateDomain(tc.domain, tc.pattern)
if result != tc.expected {
t.Errorf("Expected validateOrigin('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
}
}
}
// go test -run -v Test_normalizeDomain
func Test_normalizeDomain(t *testing.T) {
testCases := []struct {
input string
expectedOutput string
}{
{"http://example.com", "example.com"}, // Simple case with http scheme.
{"https://example.com", "example.com"}, // Simple case with https scheme.
{"http://example.com:3000", "example.com"}, // Case with port.
{"https://example.com:3000", "example.com"}, // Case with port and https scheme.
{"http://example.com/path", "example.com/path"}, // Case with path.
{"http://example.com?query=123", "example.com?query=123"}, // Case with query.
{"http://example.com#fragment", "example.com#fragment"}, // Case with fragment.
{"example.com", "example.com"}, // Case without scheme.
{"example.com:8080", "example.com"}, // Case without scheme but with port.
{"sub.example.com", "sub.example.com"}, // Case with subdomain.
{"sub.sub.example.com", "sub.sub.example.com"}, // Case with nested subdomain.
{"http://localhost", "localhost"}, // Case with localhost.
{"http://127.0.0.1", "127.0.0.1"}, // Case with IPv4 address.
{"http://[::1]", "[::1]"}, // Case with IPv6 address.
}
for _, tc := range testCases {
output := normalizeDomain(tc.input)
if output != tc.expectedOutput {
t.Errorf("Expected normalized domain '%s' for input '%s', but got: '%s'", tc.expectedOutput, tc.input, output)
}
}
}