mirror of https://github.com/gofiber/fiber.git
🚀 [Feature]: middleware/csrf custom extractor (#2052)
* feat(middleware/csrf): allow custom Extractor * test: update Test_CSRF_From_Custom * docs: add comma * docs: update KeyLookup docspull/2058/head
parent
506f0b21c5
commit
6272d759eb
|
@ -49,9 +49,12 @@ app.Use(csrf.New(csrf.Config{
|
|||
CookieSameSite: "Lax",
|
||||
Expiration: 1 * time.Hour,
|
||||
KeyGenerator: utils.UUID,
|
||||
Extractor: func(c *fiber.Ctx) (string, error) { ... },
|
||||
}))
|
||||
```
|
||||
|
||||
Note: KeyLookup will be ignored if Extractor is explicitly set.
|
||||
|
||||
### Custom Storage/Database
|
||||
|
||||
You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
|
||||
|
@ -74,7 +77,7 @@ type Config struct {
|
|||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// to create an Extractor that extracts the token from the request.
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
|
@ -82,6 +85,8 @@ type Config struct {
|
|||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
//
|
||||
// Ignored if an Extractor is explicitly set.
|
||||
//
|
||||
// Optional. Default: "header:X-CSRF-Token"
|
||||
KeyLookup string
|
||||
|
||||
|
@ -133,6 +138,13 @@ type Config struct {
|
|||
//
|
||||
// Optional. Default: utils.UUID
|
||||
KeyGenerator func() string
|
||||
|
||||
// Extractor returns the csrf token
|
||||
//
|
||||
// If set this will be used in place of an Extractor based on KeyLookup.
|
||||
//
|
||||
// Optional. Default will create an Extractor based on KeyLookup.
|
||||
Extractor func(c *fiber.Ctx) (string, error)
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ type Config struct {
|
|||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// to create an Extractor that extracts the token from the request.
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
|
@ -26,6 +26,8 @@ type Config struct {
|
|||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
//
|
||||
// Ignored if an Extractor is explicitly set.
|
||||
//
|
||||
// Optional. Default: "header:X-CSRF-Token"
|
||||
KeyLookup string
|
||||
|
||||
|
@ -92,8 +94,12 @@ type Config struct {
|
|||
// Optional. Default: DefaultErrorHandler
|
||||
ErrorHandler fiber.ErrorHandler
|
||||
|
||||
// extractor returns the csrf token from the request based on KeyLookup
|
||||
extractor func(c *fiber.Ctx) (string, error)
|
||||
// Extractor returns the csrf token
|
||||
//
|
||||
// If set this will be used in place of an Extractor based on KeyLookup.
|
||||
//
|
||||
// Optional. Default will create an Extractor based on KeyLookup.
|
||||
Extractor func(c *fiber.Ctx) (string, error)
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
|
@ -104,7 +110,7 @@ var ConfigDefault = Config{
|
|||
Expiration: 1 * time.Hour,
|
||||
KeyGenerator: utils.UUID,
|
||||
ErrorHandler: defaultErrorHandler,
|
||||
extractor: csrfFromHeader("X-Csrf-Token"),
|
||||
Extractor: CsrfFromHeader("X-Csrf-Token"),
|
||||
}
|
||||
|
||||
// default ErrorHandler that process return error from fiber.Handler
|
||||
|
@ -174,18 +180,20 @@ func configDefault(config ...Config) Config {
|
|||
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
|
||||
}
|
||||
|
||||
// By default we extract from a header
|
||||
cfg.extractor = csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
|
||||
if cfg.Extractor == nil {
|
||||
// By default we extract from a header
|
||||
cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
|
||||
|
||||
switch selectors[0] {
|
||||
case "form":
|
||||
cfg.extractor = csrfFromForm(selectors[1])
|
||||
case "query":
|
||||
cfg.extractor = csrfFromQuery(selectors[1])
|
||||
case "param":
|
||||
cfg.extractor = csrfFromParam(selectors[1])
|
||||
case "cookie":
|
||||
cfg.extractor = csrfFromCookie(selectors[1])
|
||||
switch selectors[0] {
|
||||
case "form":
|
||||
cfg.Extractor = CsrfFromForm(selectors[1])
|
||||
case "query":
|
||||
cfg.Extractor = CsrfFromQuery(selectors[1])
|
||||
case "param":
|
||||
cfg.Extractor = CsrfFromParam(selectors[1])
|
||||
case "cookie":
|
||||
cfg.Extractor = CsrfFromCookie(selectors[1])
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
|
|
|
@ -39,7 +39,7 @@ func New(config ...Config) fiber.Handler {
|
|||
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
||||
|
||||
// Extract token from client request i.e. header, query, param, form or cookie
|
||||
token, err = cfg.extractor(c)
|
||||
token, err = cfg.Extractor(c)
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
|
|
|
@ -236,6 +236,50 @@ func Test_CSRF_From_Cookie(t *testing.T) {
|
|||
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
func Test_CSRF_From_Custom(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
extractor := func(c *fiber.Ctx) (string, error) {
|
||||
body := string(c.Body())
|
||||
// Generate the correct extractor to get the token from the correct location
|
||||
selectors := strings.Split(body, "=")
|
||||
|
||||
if len(selectors) != 2 || selectors[1] == "" {
|
||||
return "", errMissingParam
|
||||
}
|
||||
return selectors[1], nil
|
||||
}
|
||||
|
||||
app.Use(New(Config{Extractor: extractor}))
|
||||
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Invalid CSRF token
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
||||
ctx.Request.SetBodyString("_csrf=" + token)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ var (
|
|||
)
|
||||
|
||||
// csrfFromParam returns a function that extracts token from the url param string.
|
||||
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
||||
func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Params(param)
|
||||
if token == "" {
|
||||
|
@ -26,7 +26,7 @@ func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// csrfFromForm returns a function that extracts a token from a multipart-form.
|
||||
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
||||
func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.FormValue(param)
|
||||
if token == "" {
|
||||
|
@ -37,7 +37,7 @@ func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// csrfFromCookie returns a function that extracts token from the cookie header.
|
||||
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
|
||||
func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Cookies(param)
|
||||
if token == "" {
|
||||
|
@ -48,7 +48,7 @@ func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// csrfFromHeader returns a function that extracts token from the request header.
|
||||
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
|
||||
func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Get(param)
|
||||
if token == "" {
|
||||
|
@ -59,7 +59,7 @@ func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// csrfFromQuery returns a function that extracts token from the query string.
|
||||
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
||||
func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Query(param)
|
||||
if token == "" {
|
||||
|
|
Loading…
Reference in New Issue