mirror of https://github.com/gofiber/fiber.git
272 lines
6.7 KiB
Go
272 lines
6.7 KiB
Go
package csrf
|
|
|
|
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/utils"
)
|
|
|
|
// Config defines the config for middleware.
|
|
type Config struct {
|
|
// Next defines a function to skip this middleware when returned true.
|
|
//
|
|
// Optional. Default: nil
|
|
Next func(c *fiber.Ctx) bool
|
|
|
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
|
// to extract token from the request.
|
|
//
|
|
// Optional. Default value "header:X-CSRF-Token".
|
|
// Possible values:
|
|
// - "header:<name>"
|
|
// - "query:<name>"
|
|
// - "param:<name>"
|
|
// - "form:<name>"
|
|
// - "cookie:<name>"
|
|
TokenLookup string
|
|
|
|
// Cookie
|
|
//
|
|
// Optional.
|
|
Cookie *fiber.Cookie
|
|
|
|
// Deprecated, please use Expiration
|
|
CookieExpires time.Duration
|
|
|
|
// Expiration is the duration before csrf token will expire
|
|
//
|
|
// Optional. Default: 24 * time.Hour
|
|
Expiration time.Duration
|
|
|
|
// Context key to store generated CSRF token into context.
|
|
//
|
|
// Optional. Default value "csrf".
|
|
ContextKey string
|
|
}
|
|
|
|
// ConfigDefault is the default config
|
|
var ConfigDefault = Config{
|
|
Next: nil,
|
|
TokenLookup: "header:X-CSRF-Token",
|
|
ContextKey: "csrf",
|
|
Cookie: &fiber.Cookie{
|
|
Name: "_csrf",
|
|
SameSite: "Strict",
|
|
},
|
|
Expiration: 24 * time.Hour,
|
|
CookieExpires: 24 * time.Hour, // deprecated
|
|
}
|
|
|
|
// storage is the in-memory token table owned by one middleware instance.
// tokens maps each issued token to the Unix time (seconds) at which it
// expires; the embedded RWMutex must be held while touching the map.
type storage struct {
	tokens map[string]int64
	sync.RWMutex
}
|
|
|
|
// New creates a new middleware handler
|
|
func New(config ...Config) fiber.Handler {
|
|
// Set default config
|
|
cfg := ConfigDefault
|
|
|
|
// Override config if provided
|
|
if len(config) > 0 {
|
|
cfg = config[0]
|
|
|
|
// Set default values
|
|
if cfg.TokenLookup == "" {
|
|
cfg.TokenLookup = ConfigDefault.TokenLookup
|
|
}
|
|
if cfg.ContextKey == "" {
|
|
cfg.ContextKey = ConfigDefault.ContextKey
|
|
}
|
|
if cfg.CookieExpires != 0 {
|
|
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration")
|
|
cfg.CookieExpires = ConfigDefault.Expiration
|
|
}
|
|
if cfg.Expiration == 0 {
|
|
cfg.Expiration = ConfigDefault.Expiration
|
|
}
|
|
if cfg.Cookie != nil {
|
|
if cfg.Cookie.Name == "" {
|
|
cfg.Cookie.Name = ConfigDefault.Cookie.Name
|
|
}
|
|
if cfg.Cookie.SameSite == "" {
|
|
cfg.Cookie.SameSite = ConfigDefault.Cookie.SameSite
|
|
}
|
|
} else {
|
|
cfg.Cookie = ConfigDefault.Cookie
|
|
}
|
|
}
|
|
expiration := int64(cfg.Expiration.Seconds())
|
|
|
|
// Generate the correct extractor to get the token from the correct location
|
|
selectors := strings.Split(cfg.TokenLookup, ":")
|
|
|
|
if len(selectors) != 2 {
|
|
panic("csrf: Token lookup must in the form of <source>:<key>")
|
|
}
|
|
|
|
// By default we extract from a header
|
|
extractor := csrfFromHeader(selectors[1])
|
|
|
|
switch selectors[0] {
|
|
case "form":
|
|
extractor = csrfFromForm(selectors[1])
|
|
case "query":
|
|
extractor = csrfFromQuery(selectors[1])
|
|
case "param":
|
|
extractor = csrfFromParam(selectors[1])
|
|
case "cookie":
|
|
if selectors[1] == cfg.Cookie.Name {
|
|
panic(fmt.Sprintf("TokenLookup key %s can't be the same as Cookie.Name %s", selectors[1], cfg.Cookie.Name))
|
|
}
|
|
extractor = csrfFromCookie(selectors[1])
|
|
}
|
|
|
|
// create new db
|
|
db := storage{
|
|
tokens: make(map[string]int64),
|
|
}
|
|
// Remove expired entries
|
|
go func() {
|
|
for {
|
|
// GC the tokens every 10 seconds to avoid
|
|
time.Sleep(10 * time.Second)
|
|
db.Lock()
|
|
for t := range db.tokens {
|
|
if time.Now().Unix() >= db.tokens[t] {
|
|
delete(db.tokens, t)
|
|
}
|
|
}
|
|
db.Unlock()
|
|
}
|
|
}()
|
|
|
|
// Return new handler
|
|
return func(c *fiber.Ctx) error {
|
|
// Don't execute middleware if Next returns true
|
|
if cfg.Next != nil && cfg.Next(c) {
|
|
return c.Next()
|
|
}
|
|
|
|
// Declare empty token and try to get previous generated CSRF from cookie
|
|
token, key := "", c.Cookies(cfg.Cookie.Name)
|
|
|
|
// Check if the cookie had a CSRF token
|
|
if key == "" {
|
|
// Create a new CSRF token
|
|
token = utils.UUID()
|
|
// Add token with timestamp expiration
|
|
db.Lock()
|
|
db.tokens[token] = int64(time.Now().Unix()) + expiration
|
|
db.Unlock()
|
|
} else {
|
|
// Use the server generated token previously to compare
|
|
// To the extracted token later on
|
|
token = key
|
|
}
|
|
|
|
// Verify CSRF token on POST requests
|
|
if c.Method() == fiber.MethodPost {
|
|
// Extract token from client request i.e. header, query, param or form
|
|
csrf, err := extractor(c)
|
|
if err != nil {
|
|
// We have a problem extracting the csrf token
|
|
return fiber.ErrForbidden
|
|
}
|
|
|
|
// Get token from DB
|
|
db.RLock()
|
|
t, ok := db.tokens[csrf]
|
|
db.RUnlock()
|
|
// Check if token exist or expired
|
|
if !ok || time.Now().Unix() >= t {
|
|
return fiber.ErrForbidden
|
|
}
|
|
}
|
|
|
|
// Create new cookie to send new CSRF token
|
|
cookie := &fiber.Cookie{
|
|
Name: cfg.Cookie.Name,
|
|
Value: token,
|
|
Domain: cfg.Cookie.Domain,
|
|
Path: cfg.Cookie.Path,
|
|
Expires: time.Now().Add(cfg.CookieExpires),
|
|
Secure: cfg.Cookie.Secure,
|
|
HTTPOnly: cfg.Cookie.HTTPOnly,
|
|
SameSite: cfg.Cookie.SameSite,
|
|
}
|
|
|
|
// Set cookie to response
|
|
c.Cookie(cookie)
|
|
// Store token in context
|
|
c.Locals(cfg.ContextKey, token)
|
|
|
|
// Protect clients from caching the response by telling the browser
|
|
// a new header value is generated
|
|
c.Vary(fiber.HeaderCookie)
|
|
|
|
// Continue stack
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
// csrfFromHeader returns a function that extracts token from the request header.
|
|
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
|
|
return func(c *fiber.Ctx) (string, error) {
|
|
token := c.Get(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in header")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// csrfFromQuery returns a function that extracts token from the query string.
|
|
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
|
return func(c *fiber.Ctx) (string, error) {
|
|
token := c.Query(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in query string")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// csrfFromParam returns a function that extracts token from the url param string.
|
|
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
|
return func(c *fiber.Ctx) (string, error) {
|
|
token := c.Params(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in url parameter")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// csrfFromParam returns a function that extracts token from the url param string.
|
|
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
|
return func(c *fiber.Ctx) (string, error) {
|
|
token := c.FormValue(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in form parameter")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// csrfFromCookie returns a function that extracts token from the cookie header.
|
|
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
|
|
return func(c *fiber.Ctx) (string, error) {
|
|
token := c.Cookies(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in cookie")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|