feat(cors): Added new 'AllowOriginsFunc' function. (#2394)

*  feat(cors): Added new 'AllowOriginsFunc' function.

* feat(cors): Added warning log for when both 'AllowOrigins' and 'AllowOriginsFunc' are set.

* feat(docs): Updated docs to include note about discouraging the use of this function in production workloads.

---------

Co-authored-by: RW <rene@gofiber.io>
pull/2410/head
James Lucas 2023-04-11 09:24:29 +01:00 committed by GitHub
parent fcf708dfc2
commit 866d5b7628
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 37 deletions

View File

@ -35,52 +35,77 @@ app.Use(cors.New(cors.Config{
}))
```
Using the `AllowOriginsFunc` function. In this example any origin will be allowed via CORS.
For example, if a browser running on `http://localhost:3000` sends a request, this will be accepted and the `access-control-allow-origin` response header will be set to `http://localhost:3000`.
**Note: Using this feature is discouraged in production and it's best practice to explicitly set CORS origins via `AllowOrigins`.**
```go
app.Use(cors.New())
app.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return os.Getenv("ENVIRONMENT") == "development"
},
}))
```
## Config
```go
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
//
// Note: Using this feature is discouraged in production and it's best practice to explicitly
// set CORS origins via 'AllowOrigins'
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
// AllowMethods defines a list of methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string
// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
//
// Optional. Default value false.
AllowCredentials bool
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
//
// Optional. Default value false.
AllowCredentials bool
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
//
// Optional. Default value 0.
MaxAge int
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
//
// Optional. Default value 0.
MaxAge int
}
```
@ -89,6 +114,7 @@ type Config struct {
```go
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,

View File

@ -1,6 +1,7 @@
package cors
import (
"log"
"strconv"
"strings"
@ -14,6 +15,12 @@ type Config struct {
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
@ -54,8 +61,9 @@ type Config struct {
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
AllowOrigins: "*",
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,
fiber.MethodPost,
@ -88,6 +96,11 @@ func New(config ...Config) fiber.Handler {
}
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Printf("[CORS] - [Warning] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.\n")
}
// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")
@ -126,6 +139,15 @@ func New(config ...Config) fiber.Handler {
}
}
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}
}
// Simple request
if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin)

View File

@ -2,6 +2,7 @@ package cors
import (
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
@ -242,3 +243,54 @@ func Test_CORS_Next(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_CORS_AllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOrigins: "http://example-1.com",
AllowOriginsFunc: func(origin string) bool {
return strings.Contains(origin, "example-2")
},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
handler(ctx)
utils.AssertEqual(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}