fiber/middleware/cors/cors.go

215 lines
6.8 KiB
Go

package cors
import (
"strconv"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// New creates a new CORS middleware handler.
//
// At most one Config is honoured (the first element of config). Without a
// Config, ConfigDefault is used; with a Config whose AllowMethods is empty,
// ConfigDefault.AllowMethods is substituted.
//
// All origins are allowed when AllowOrigins contains the exact entry "*", or
// when neither AllowOrigins nor AllowOriginsFunc is configured. Entries of the
// form "scheme://*.host" are registered as subdomain patterns; every other
// entry is registered as an exact origin. Surrounding spaces of an entry are
// ignored.
//
// New panics when normalizeOrigin rejects a configured origin, or when
// AllowCredentials is enabled while all origins are allowed.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := ConfigDefault

	// Override config if provided
	if len(config) > 0 {
		cfg = config[0]

		// Set default values
		if len(cfg.AllowMethods) == 0 {
			cfg.AllowMethods = ConfigDefault.AllowMethods
		}
	}

	// Warn if both AllowOrigins and AllowOriginsFunc are set
	if len(cfg.AllowOrigins) > 0 && cfg.AllowOriginsFunc != nil {
		log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
	}

	// allowOrigins holds the normalized exact origins and allowSOrigins the
	// wildcard-subdomain patterns taken from the 'AllowOrigins' configuration.
	allowOrigins := make([]string, 0, len(cfg.AllowOrigins))
	var allowSOrigins []subdomain
	allowAllOrigins := false

	// With no static list and no callback there is nothing to restrict on.
	if len(cfg.AllowOrigins) == 0 && cfg.AllowOriginsFunc == nil {
		allowAllOrigins = true
	}

	// Validate and normalize static AllowOrigins
	for _, origin := range cfg.AllowOrigins {
		if origin == "*" {
			allowAllOrigins = true
			break
		}

		// Trim before looking for the wildcard marker so that every index
		// computed below refers to the string that is actually normalized.
		// Searching the untrimmed entry shifted the prefix/suffix split (or
		// sliced out of range) whenever the entry had leading spaces.
		trimmedOrigin := utils.Trim(origin, ' ')

		if i := strings.Index(trimmedOrigin, "://*."); i != -1 {
			// Drop only the '*': "https://*.example.com" becomes
			// "https://.example.com", so the suffix keeps its leading dot.
			withoutWildcard := trimmedOrigin[:i+3] + trimmedOrigin[i+4:]
			isValid, normalizedOrigin := normalizeOrigin(withoutWildcard)
			if !isValid {
				panic("[CORS] Invalid origin format in configuration: " + withoutWildcard)
			}
			// Split right after "://".
			// NOTE(review): this assumes normalizeOrigin keeps the scheme part
			// the same length as in its input — confirm against its definition.
			sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
			allowSOrigins = append(allowSOrigins, sd)
			continue
		}

		isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
		if !isValid {
			panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
		}
		allowOrigins = append(allowOrigins, normalizedOrigin)
	}

	// Validate CORS credentials configuration
	if cfg.AllowCredentials && allowAllOrigins {
		panic("[CORS] Configuration error: When 'AllowCredentials' is set to true, 'AllowOrigins' cannot contain a wildcard origin '*'. Please specify allowed origins explicitly or adjust 'AllowCredentials' setting.")
	}

	// Warn if allowAllOrigins is set to true and AllowOriginsFunc is defined
	if allowAllOrigins && cfg.AllowOriginsFunc != nil {
		log.Warn("[CORS] 'AllowOrigins' is set to allow all origins, 'AllowOriginsFunc' will not be used.")
	}

	// Pre-compute the header values that do not depend on the request, so
	// the handler does not rebuild them on every call.
	maxAge := strconv.Itoa(cfg.MaxAge)
	allowMethods := strings.Join(cfg.AllowMethods, ", ")
	allowHeaders := strings.Join(cfg.AllowHeaders, ", ")

	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Get the Origin header, lower-cased for comparison with the
		// configured origins.
		originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

		// If the request does not have an Origin header, the request is outside the scope of CORS
		if originHeader == "" {
			// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
			// Unless all origins are allowed, we include the Vary header to cache the response correctly
			if !allowAllOrigins {
				c.Vary(fiber.HeaderOrigin)
			}
			return c.Next()
		}

		// An OPTIONS request without Access-Control-Request-Method is not a
		// preflight request, so it is outside the scope of CORS.
		if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
			// Responses to OPTIONS requests should not be cached, but some
			// caches can be configured to do so. To avoid poisoning such a
			// cache, include the Vary header for non-CORS OPTIONS requests.
			c.Vary(fiber.HeaderOrigin)
			return c.Next()
		}

		// allowOrigin stays empty unless the request origin is accepted.
		allowOrigin := ""

		// Check allowed origins
		if allowAllOrigins {
			allowOrigin = "*"
		} else {
			// Check if the origin is in the list of allowed origins
			for _, origin := range allowOrigins {
				if origin == originHeader {
					allowOrigin = originHeader
					break
				}
			}

			// Check if the origin is in the list of allowed subdomains
			if allowOrigin == "" {
				for _, sOrigin := range allowSOrigins {
					if sOrigin.match(originHeader) {
						allowOrigin = originHeader
						break
					}
				}
			}
		}

		// Run AllowOriginsFunc only if the static 'AllowOrigins' handling
		// did not already accept the origin.
		if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
			allowOrigin = originHeader
		}

		// Simple request
		// Omit allowMethods and allowHeaders, only used for pre-flight requests
		if c.Method() != fiber.MethodOptions {
			if !allowAllOrigins {
				// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
				c.Vary(fiber.HeaderOrigin)
			}
			setSimpleHeaders(c, allowOrigin, maxAge, cfg)
			return c.Next()
		}

		// Pre-flight request
		// Responses to OPTIONS requests should not be cached, but some caches
		// can be configured to do so. To avoid poisoning such a cache, include
		// the request headers that influence the preflight response in Vary.
		c.Vary(fiber.HeaderAccessControlRequestMethod)
		c.Vary(fiber.HeaderAccessControlRequestHeaders)
		if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" {
			c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork)
			c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true")
		}
		c.Vary(fiber.HeaderOrigin)

		setSimpleHeaders(c, allowOrigin, maxAge, cfg)

		// Set Preflight headers
		if len(cfg.AllowMethods) > 0 {
			c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
		}
		if len(cfg.AllowHeaders) > 0 {
			c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
		} else if h := c.Get(fiber.HeaderAccessControlRequestHeaders); h != "" {
			// No explicit allow-list configured: echo the requested headers.
			c.Set(fiber.HeaderAccessControlAllowHeaders, h)
		}

		// Send 204 No Content
		return c.SendStatus(fiber.StatusNoContent)
	}
}
// setSimpleHeaders writes the CORS response headers that apply to both simple
// and pre-flight responses.
//
//   - Access-Control-Allow-Origin is set to allowOrigin whenever allowOrigin is
//     non-empty.
//   - With cfg.AllowCredentials enabled, Access-Control-Allow-Credentials is
//     set to "true" for a specific origin; for "*" a warning is logged and the
//     credentials header is left out.
//   - Access-Control-Max-Age is set to maxAge when cfg.MaxAge is positive, to
//     "0" when it is negative, and is not set when it is zero.
//   - Access-Control-Expose-Headers is set when cfg.ExposeHeaders is non-empty.
func setSimpleHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg Config) {
	// An empty allowOrigin means the origin was not accepted: emit no
	// origin or credentials headers at all.
	if allowOrigin != "" {
		c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)

		if cfg.AllowCredentials {
			if allowOrigin == "*" {
				log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
			} else {
				c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
			}
		}
	}

	// Max-Age: positive values use the pre-formatted string, negative
	// values are reported as "0", zero leaves the header unset.
	switch {
	case cfg.MaxAge > 0:
		c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
	case cfg.MaxAge < 0:
		c.Set(fiber.HeaderAccessControlMaxAge, "0")
	}

	// Expose-Headers only when something is configured
	if len(cfg.ExposeHeaders) == 0 {
		return
	}
	c.Set(fiber.HeaderAccessControlExposeHeaders, strings.Join(cfg.ExposeHeaders, ", "))
}