Jason McNeil 53e5dc523e
🩹 Fix: CSRF middleware cookie<>storage bug squashed and other improvements (#1180)
* expire cookie on Post, Delete, Patch and Put

The cookie should always expire on Post, Delete, Patch and Put, because its token is either valid and will be removed from storage, or is not in storage and is therefore invalid.

* token and cookie match

* retrigger checks

* csrf tests

* csrf per session strategy
2021-03-01 17:44:17 +01:00

181 lines
4.7 KiB
Go

package csrf
import (
"errors"
"net/textproto"
"strings"
"time"
"github.com/gofiber/fiber/v2"
)
// New creates a new CSRF middleware handler.
//
// The optional config is resolved through configDefault. New panics when
// cfg.KeyLookup is not of the form "<source>:<key>".
//
// For the safe methods (GET, HEAD, OPTIONS, TRACE) the handler reuses the
// token found in the cfg.CookieName cookie, or generates a new one with
// cfg.KeyGenerator when that cookie is empty. For every other method the
// token is extracted from the location selected by cfg.KeyLookup and must be
// present in storage; otherwise the CSRF cookie is expired and
// cfg.ErrorHandler is invoked with a non-nil error. On success the token is
// (re)stored with cfg.Expiration, sent back in the cookie and, when
// cfg.ContextKey is set, stored in c.Locals under that key.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage)

	// Generate the correct extractor to get the token from the correct location
	selectors := strings.Split(cfg.KeyLookup, ":")
	if len(selectors) != 2 {
		panic("[CSRF] KeyLookup must in the form of <source>:<key>")
	}

	// By default we extract from a header.
	// NOTE(review): a source that matches none of the cases below silently
	// falls back to the header extractor as well.
	extractor := csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
	switch selectors[0] {
	case "form":
		extractor = csrfFromForm(selectors[1])
	case "query":
		extractor = csrfFromQuery(selectors[1])
	case "param":
		extractor = csrfFromParam(selectors[1])
	case "cookie":
		extractor = csrfFromCookie(selectors[1])
	}

	// errTokenNotFound is passed to cfg.ErrorHandler when a token was
	// extracted from the request but is not present in storage. Without it
	// the handler would receive a nil error, because extraction itself
	// succeeded.
	errTokenNotFound := errors.New("csrf token not found")

	// Placeholder stored under each token key; only the key's presence in
	// storage is checked by this handler.
	dummyValue := []byte{'+'}

	// Return new handler
	return func(c *fiber.Ctx) (err error) {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		var token string

		// Action depends on the HTTP method
		switch c.Method() {
		case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
			// Try to get an existing CSRF token from the cookie; an empty
			// value triggers generation of a fresh token below.
			// NOTE(review): the cookie value is not checked against storage
			// here before being stored again below.
			token = c.Cookies(cfg.CookieName)
		default:
			// Assume that anything not defined as 'safe' by RFC7231 needs protection

			// Extract token from client request i.e. header, query, param, form or cookie
			token, err = extractor(c)
			if err != nil {
				return cfg.ErrorHandler(c, err)
			}

			// If the token does not exist in Storage
			if manager.getRaw(token) == nil {
				// Expire cookie
				c.Cookie(&fiber.Cookie{
					Name:     cfg.CookieName,
					Domain:   cfg.CookieDomain,
					Path:     cfg.CookiePath,
					Expires:  time.Now().Add(-1 * time.Minute),
					Secure:   cfg.CookieSecure,
					HTTPOnly: cfg.CookieHTTPOnly,
					SameSite: cfg.CookieSameSite,
				})

				// err is nil at this point, so report an explicit error.
				return cfg.ErrorHandler(c, errTokenNotFound)
			}
		}

		// Generate a new CSRF token if none exists yet
		if token == "" {
			token = cfg.KeyGenerator()
		}

		// Add/update token to Storage
		manager.setRaw(token, dummyValue, cfg.Expiration)

		// Create cookie to pass token to client
		cookie := &fiber.Cookie{
			Name:     cfg.CookieName,
			Value:    token,
			Domain:   cfg.CookieDomain,
			Path:     cfg.CookiePath,
			Expires:  time.Now().Add(cfg.Expiration),
			Secure:   cfg.CookieSecure,
			HTTPOnly: cfg.CookieHTTPOnly,
			SameSite: cfg.CookieSameSite,
		}
		// Set cookie to response
		c.Cookie(cookie)

		// Protect clients from caching the response by telling the browser
		// a new header value is generated
		c.Vary(fiber.HeaderCookie)

		// Store token in context if set
		if cfg.ContextKey != "" {
			c.Locals(cfg.ContextKey, token)
		}

		// Continue stack
		return c.Next()
	}
}
// Errors returned by the token extractors in this file when the CSRF token
// is empty at the corresponding request location.
var (
// errMissingHeader is returned by the extractor built by csrfFromHeader.
errMissingHeader = errors.New("missing csrf token in header")
// errMissingQuery is returned by the extractor built by csrfFromQuery.
errMissingQuery = errors.New("missing csrf token in query")
// errMissingParam is returned by the extractor built by csrfFromParam.
errMissingParam = errors.New("missing csrf token in param")
// errMissingForm is returned by the extractor built by csrfFromForm.
errMissingForm = errors.New("missing csrf token in form")
// errMissingCookie is returned by the extractor built by csrfFromCookie.
errMissingCookie = errors.New("missing csrf token in cookie")
)
// csrfFromHeader builds an extractor that reads the CSRF token from the
// request header named param. The extractor returns errMissingHeader when
// that header value is empty.
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
	return func(c *fiber.Ctx) (string, error) {
		if value := c.Get(param); value != "" {
			return value, nil
		}
		return "", errMissingHeader
	}
}
// csrfFromQuery builds an extractor that reads the CSRF token from the query
// string parameter named param. The extractor returns errMissingQuery when
// that parameter is empty.
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
	return func(c *fiber.Ctx) (string, error) {
		if value := c.Query(param); value != "" {
			return value, nil
		}
		return "", errMissingQuery
	}
}
// csrfFromParam builds an extractor that reads the CSRF token from the URL
// route parameter named param. The extractor returns errMissingParam when
// that parameter is empty.
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
	return func(c *fiber.Ctx) (string, error) {
		if value := c.Params(param); value != "" {
			return value, nil
		}
		return "", errMissingParam
	}
}
// csrfFromForm builds an extractor that reads the CSRF token from the form
// value named param. The extractor returns errMissingForm when that value is
// empty.
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
	return func(c *fiber.Ctx) (string, error) {
		if value := c.FormValue(param); value != "" {
			return value, nil
		}
		return "", errMissingForm
	}
}
// csrfFromCookie builds an extractor that reads the CSRF token from the
// request cookie named param. The extractor returns errMissingCookie when
// that cookie value is empty.
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
	return func(c *fiber.Ctx) (string, error) {
		if value := c.Cookies(param); value != "" {
			return value, nil
		}
		return "", errMissingCookie
	}
}