🚀 [Feature]: middleware/csrf custom extractor (#2052)

* feat(middleware/csrf): allow custom Extractor

* test: update Test_CSRF_From_Custom

* docs: add comma

* docs: update KeyLookup docs
pull/2058/head
Jason McNeil 2022-08-28 13:57:47 -03:00 committed by GitHub
parent 506f0b21c5
commit 6272d759eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 22 deletions

View File

@ -49,9 +49,12 @@ app.Use(csrf.New(csrf.Config{
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
Extractor: func(c *fiber.Ctx) (string, error) { ... },
}))
```
Note: KeyLookup will be ignored if Extractor is explicitly set.
### Custom Storage/Database
You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
@ -74,7 +77,7 @@ type Config struct {
Next func(c *fiber.Ctx) bool
// KeyLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// to create an Extractor that extracts the token from the request.
// Possible values:
// - "header:<name>"
// - "query:<name>"
@ -82,6 +85,8 @@ type Config struct {
// - "form:<name>"
// - "cookie:<name>"
//
// Ignored if an Extractor is explicitly set.
//
// Optional. Default: "header:X-CSRF-Token"
KeyLookup string
@ -133,6 +138,13 @@ type Config struct {
//
// Optional. Default: utils.UUID
KeyGenerator func() string
// Extractor returns the csrf token
//
// If set this will be used in place of an Extractor based on KeyLookup.
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)
}
```

View File

@ -18,7 +18,7 @@ type Config struct {
Next func(c *fiber.Ctx) bool
// KeyLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// to create an Extractor that extracts the token from the request.
// Possible values:
// - "header:<name>"
// - "query:<name>"
@ -26,6 +26,8 @@ type Config struct {
// - "form:<name>"
// - "cookie:<name>"
//
// Ignored if an Extractor is explicitly set.
//
// Optional. Default: "header:X-CSRF-Token"
KeyLookup string
@ -92,8 +94,12 @@ type Config struct {
// Optional. Default: DefaultErrorHandler
ErrorHandler fiber.ErrorHandler
// extractor returns the csrf token from the request based on KeyLookup
extractor func(c *fiber.Ctx) (string, error)
// Extractor returns the csrf token
//
// If set this will be used in place of an Extractor based on KeyLookup.
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)
}
// ConfigDefault is the default config
@ -104,7 +110,7 @@ var ConfigDefault = Config{
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
ErrorHandler: defaultErrorHandler,
extractor: csrfFromHeader("X-Csrf-Token"),
Extractor: CsrfFromHeader("X-Csrf-Token"),
}
// default ErrorHandler that process return error from fiber.Handler
@ -174,18 +180,20 @@ func configDefault(config ...Config) Config {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}
// By default we extract from a header
cfg.extractor = csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
if cfg.Extractor == nil {
// By default we extract from a header
cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
switch selectors[0] {
case "form":
cfg.extractor = csrfFromForm(selectors[1])
case "query":
cfg.extractor = csrfFromQuery(selectors[1])
case "param":
cfg.extractor = csrfFromParam(selectors[1])
case "cookie":
cfg.extractor = csrfFromCookie(selectors[1])
switch selectors[0] {
case "form":
cfg.Extractor = CsrfFromForm(selectors[1])
case "query":
cfg.Extractor = CsrfFromQuery(selectors[1])
case "param":
cfg.Extractor = CsrfFromParam(selectors[1])
case "cookie":
cfg.Extractor = CsrfFromCookie(selectors[1])
}
}
return cfg

View File

@ -39,7 +39,7 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection
// Extract token from client request i.e. header, query, param, form or cookie
token, err = cfg.extractor(c)
token, err = cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}

View File

@ -236,6 +236,50 @@ func Test_CSRF_From_Cookie(t *testing.T) {
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}
func Test_CSRF_From_Custom(t *testing.T) {
app := fiber.New()
extractor := func(c *fiber.Ctx) (string, error) {
body := string(c.Body())
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(body, "=")
if len(selectors) != 2 || selectors[1] == "" {
return "", errMissingParam
}
return selectors[1], nil
}
app.Use(New(Config{Extractor: extractor}))
app.Post("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
ctx.Request.SetBodyString("_csrf=" + token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
}
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
app := fiber.New()

View File

@ -15,7 +15,7 @@ var (
)
// csrfFromParam returns a function that extracts token from the url param string.
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
@ -26,7 +26,7 @@ func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
}
// csrfFromForm returns a function that extracts a token from a multipart-form.
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.FormValue(param)
if token == "" {
@ -37,7 +37,7 @@ func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
}
// csrfFromCookie returns a function that extracts token from the cookie header.
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(param)
if token == "" {
@ -48,7 +48,7 @@ func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
}
// csrfFromHeader returns a function that extracts token from the request header.
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Get(param)
if token == "" {
@ -59,7 +59,7 @@ func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
}
// csrfFromQuery returns a function that extracts token from the query string.
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {