mirror of
https://github.com/gofiber/fiber.git
synced 2025-05-13 03:06:11 +00:00
* feat!(middleware/session): re-write session middleware with handler * test(middleware/session): refactor to IdleTimeout * fix: lint errors * test: Save session after setting or deleting raw data in CSRF middleware * Update middleware/session/middleware.go Co-authored-by: Renan Bastos <renanbastos.tec@gmail.com> * fix: mutex and globals order * feat: Re-Add read lock to session Get method * feat: Migrate New() to return middleware * chore: Refactor session middleware to improve session handling * chore: Private get on store * chore: Update session middleware to use saveSession instead of save * chore: Update session middleware to use getSession instead of get * chore: Remove unused error handler in session middleware config * chore: Update session middleware to use NewWithStore in CSRF tests * test: add test * fix: destroyed session and GHSA-98j2-3j3p-fw2v * chore: Refactor session_test.go to use newStore() instead of New() * feat: Improve session middleware test coverage and error handling This commit improves the session middleware test coverage by adding assertions for the presence of the Set-Cookie header and the token value. It also enhances error handling by checking for the expected number of parts in the Set-Cookie header. 
* chore: fix lint issues * chore: Fix session middleware locking issue and improve error handling * test: improve middleware test coverage and error handling * test: Add idle timeout test case to session middleware test * feat: add GetSession(id string) (*Session, error) * chore: lint * docs: Update session middleware docs * docs: Security Note to examples * docs: Add recommendation for CSRF protection in session middleware * chore: markdown lint * docs: Update session middleware docs * docs: makrdown lint * test(middleware/session): Add unit tests for session config.go * test(middleware/session): Add unit tests for store.go * test(middleware/session): Add data.go unit tests * refactor(middleware/session): session tests and add session release test - Refactor session tests to improve readability and maintainability. - Add a new test case to ensure proper session release functionality. - Update session.md * refactor: session data locking in middleware/session/data.go * refactor(middleware/session): Add unit test for session middleware store * test: fix session_test.go and store_test.go unit tests * refactor(docs): Update session.md with v3 changes to Expiration * refactor(middleware/session): Improve data pool handling and locking * chore(middleware/session): TODO for Expiration field in session config * refactor(middleware/session): Improve session data pool handling and locking * refactor(middleware/session): Improve session data pool handling and locking * test(middleware/csrf): add session middleware coverage * chroe(middleware/session): TODO for unregistered session middleware * refactor(middleware/session): Update session middleware for v3 changes * refactor(middleware/session): Update session middleware for v3 changes * refactor(middleware/session): Update session middleware idle timeout - Update the default idle timeout for session middleware from 24 hours to 30 minutes. 
- Add a note in the session middleware documentation about the importance of the middleware order. * docws(middleware/session): Add note about IdleTimeout requiring save using legacy approach * refactor(middleware/session): Update session middleware idle timeout Update the idle timeout for the session middleware to 30 minutes. This ensures that the session expires after a period of inactivity. The previous value was 24 hours, which is too long for most use cases. This change improves the security and efficiency of the session management. * docs(middleware/session): Update session middleware idle timeout and configuration * test(middleware/session): Fix tests for updated panics * refactor(middleware/session): Update session middleware initialization and saving * refactor(middleware/session): Remove unnecessary comment about negative IdleTimeout value * refactor(middleware/session): Update session middleware make NewStore public * refactor(middleware/session): Update session middleware Set, Get, and Delete methods Refactor the Set, Get, and Delete methods in the session middleware to use more descriptive parameter names. Instead of using "middlewareContextKey", the methods now use "key" to represent the key of the session value. This improves the readability and clarity of the code. * feat(middleware/session): AbsoluteTimeout and key any * fix(middleware/session): locking issues and lint errors * chore(middleware/session): Regenerate code in data_msgp.go * refactor(middleware/session): rename GetSessionByID to GetByID This commit also includes changes to the session_test.go and store_test.go files to add test cases for the new GetByID method. 
* docs(middleware/session): AbsoluteTimeout * refactor(middleware/csrf): Rename Expiration to IdleTimeout * docs(whats-new): CSRF Rename Expiration to IdleTimeout and remove SessionKey field * refactor(middleware/session): Rename expirationKeyType to absExpirationKeyType and update related functions * refactor(middleware/session): rename Test_Session_Save_Absolute to Test_Session_Save_AbsoluteTimeout * chore(middleware/session): update as per PR comments * docs(middlware/session): fix indent lint * fix(middleware/session): Address EfeCtn Comments * refactor(middleware/session): Move bytesBuffer to it's own pool * test(middleware/session): add decodeSessionData error coverage * refactor(middleware/session): Update absolute timeout handling - Update absolute timeout handling in getSession function - Set absolute expiration time in getSession function - Delete expired session in GetByID function * refactor(session/middleware): fix *Session nil ctx when using Store.GetByID * refactor(middleware/session): Remove unnecessary line in session_test.go * fix(middleware/session): *Session lifecycle issues * docs(middleware/session): Update GetByID method documentation * docs(middleware/session): Update GetByID method documentation * docs(middleware/session): markdown lint * refactor(middleware/session): Simplify error handling in DefaultErrorHandler * fix( middleware/session/config.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * add ctx releases for the test cases --------- Co-authored-by: Renan Bastos <renanbastos.tec@gmail.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: René <rene@gofiber.io>
348 lines
10 KiB
Go
348 lines
10 KiB
Go
package csrf
|
|
|
|
import (
|
|
"errors"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/utils/v2"
|
|
)
|
|
|
|
// Errors passed to Config.ErrorHandler when a request fails CSRF validation.
var (
	// ErrTokenNotFound is reported when no token could be extracted from the
	// request, when the extracted token is not present in storage, or when
	// DeleteToken finds no CSRF cookie.
	ErrTokenNotFound = errors.New("csrf token not found")
	// ErrTokenInvalid is reported when the extracted token differs from the
	// value of the CSRF cookie (double-submit cookie check).
	ErrTokenInvalid = errors.New("csrf token invalid")
	// ErrRefererNotFound is reported when the Referer fallback check runs but
	// the request carries no Referer header.
	ErrRefererNotFound = errors.New("referer not supplied")
	// ErrRefererInvalid is reported when the Referer header cannot be parsed.
	ErrRefererInvalid = errors.New("referer invalid")
	// ErrRefererNoMatch is reported when the Referer is neither same-origin
	// nor a trusted origin.
	ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
	// ErrOriginInvalid is reported when the Origin header cannot be parsed.
	ErrOriginInvalid = errors.New("origin invalid")
	// ErrOriginNoMatch is reported when the Origin is neither same-origin nor
	// a trusted origin.
	ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
)

// Package-internal values.
var (
	// errOriginNotFound marks a missing or "null" Origin header. New handles it
	// itself (Referer fallback over HTTPS, otherwise the request is allowed),
	// so it is never handed to the ErrorHandler.
	errOriginNotFound = errors.New("origin not supplied or is null")

	// dummyValue is the placeholder payload stored with every token; the
	// middleware only checks whether a value exists for the token.
	dummyValue = []byte{'+'}
)
|
|
|
|
// Handler for CSRF middleware
|
|
type Handler struct {
|
|
sessionManager *sessionManager
|
|
storageManager *storageManager
|
|
config Config
|
|
}
|
|
|
|
// contextKey is unexported so that the keys this package stores in c.Locals
// cannot collide with context keys defined by other packages.
type contextKey int

// Keys under which the middleware stores values in c.Locals.
const (
	// tokenKey holds the CSRF token (string) for the current request.
	tokenKey contextKey = 0
	// handlerKey holds the middleware's *Handler.
	handlerKey contextKey = 1
)
|
|
|
|
// New creates a new middleware handler
|
|
func New(config ...Config) fiber.Handler {
|
|
// Set default config
|
|
cfg := configDefault(config...)
|
|
|
|
// Create manager to simplify storage operations ( see *_manager.go )
|
|
var sessionManager *sessionManager
|
|
var storageManager *storageManager
|
|
if cfg.Session != nil {
|
|
sessionManager = newSessionManager(cfg.Session)
|
|
} else {
|
|
storageManager = newStorageManager(cfg.Storage)
|
|
}
|
|
|
|
// Pre-parse trusted origins
|
|
trustedOrigins := []string{}
|
|
trustedSubOrigins := []subdomain{}
|
|
|
|
for _, origin := range cfg.TrustedOrigins {
|
|
if i := strings.Index(origin, "://*."); i != -1 {
|
|
trimmedOrigin := utils.Trim(origin[:i+3]+origin[i+4:], ' ')
|
|
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
|
if !isValid {
|
|
panic("[CSRF] Invalid origin format in configuration:" + origin)
|
|
}
|
|
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
|
trustedSubOrigins = append(trustedSubOrigins, sd)
|
|
} else {
|
|
trimmedOrigin := utils.Trim(origin, ' ')
|
|
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
|
if !isValid {
|
|
panic("[CSRF] Invalid origin format in configuration:" + origin)
|
|
}
|
|
trustedOrigins = append(trustedOrigins, normalizedOrigin)
|
|
}
|
|
}
|
|
|
|
// Create the handler outside of the returned function
|
|
handler := &Handler{
|
|
config: cfg,
|
|
sessionManager: sessionManager,
|
|
storageManager: storageManager,
|
|
}
|
|
|
|
// Return new handler
|
|
return func(c fiber.Ctx) error {
|
|
// Don't execute middleware if Next returns true
|
|
if cfg.Next != nil && cfg.Next(c) {
|
|
return c.Next()
|
|
}
|
|
|
|
// Store the CSRF handler in the context
|
|
c.Locals(handlerKey, handler)
|
|
|
|
var token string
|
|
|
|
// Action depends on the HTTP method
|
|
switch c.Method() {
|
|
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
|
|
cookieToken := c.Cookies(cfg.CookieName)
|
|
|
|
if cookieToken != "" {
|
|
raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)
|
|
|
|
if raw != nil {
|
|
token = cookieToken // Token is valid, safe to set it
|
|
}
|
|
}
|
|
default:
|
|
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
|
|
|
// Enforce an origin check for unsafe requests.
|
|
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)
|
|
|
|
// If there's no origin, enforce a referer check for HTTPS connections.
|
|
if errors.Is(err, errOriginNotFound) {
|
|
if c.Scheme() == "https" {
|
|
err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
|
|
} else {
|
|
// If it's not HTTPS, clear the error to allow the request to proceed.
|
|
err = nil
|
|
}
|
|
}
|
|
|
|
// If there's an error (either from origin check or referer check), handle it.
|
|
if err != nil {
|
|
return cfg.ErrorHandler(c, err)
|
|
}
|
|
|
|
// Extract token from client request i.e. header, query, param, form or cookie
|
|
extractedToken, err := cfg.Extractor(c)
|
|
if err != nil {
|
|
return cfg.ErrorHandler(c, err)
|
|
}
|
|
|
|
if extractedToken == "" {
|
|
return cfg.ErrorHandler(c, ErrTokenNotFound)
|
|
}
|
|
|
|
// If not using FromCookie extractor, check that the token matches the cookie
|
|
// This is to prevent CSRF attacks by using a Double Submit Cookie method
|
|
// Useful when we do not have access to the users Session
|
|
if !isFromCookie(cfg.Extractor) && !compareStrings(extractedToken, c.Cookies(cfg.CookieName)) {
|
|
return cfg.ErrorHandler(c, ErrTokenInvalid)
|
|
}
|
|
|
|
raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
|
|
|
if raw == nil {
|
|
// If token is not in storage, expire the cookie
|
|
expireCSRFCookie(c, cfg)
|
|
// and return an error
|
|
return cfg.ErrorHandler(c, ErrTokenNotFound)
|
|
}
|
|
if cfg.SingleUseToken {
|
|
// If token is single use, delete it from storage
|
|
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
|
} else {
|
|
token = extractedToken // Token is valid, safe to set it
|
|
}
|
|
}
|
|
|
|
// Generate CSRF token if not exist
|
|
if token == "" {
|
|
// And generate a new token
|
|
token = cfg.KeyGenerator()
|
|
}
|
|
|
|
// Create or extend the token in the storage
|
|
createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager)
|
|
|
|
// Update the CSRF cookie
|
|
updateCSRFCookie(c, cfg, token)
|
|
|
|
// Tell the browser that a new header value is generated
|
|
c.Vary(fiber.HeaderCookie)
|
|
|
|
// Store the token in the context
|
|
c.Locals(tokenKey, token)
|
|
|
|
// Continue stack
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
// TokenFromContext returns the token found in the context
|
|
// returns an empty string if the token does not exist
|
|
func TokenFromContext(c fiber.Ctx) string {
|
|
token, ok := c.Locals(tokenKey).(string)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return token
|
|
}
|
|
|
|
// HandlerFromContext returns the Handler found in the context
|
|
// returns nil if the handler does not exist
|
|
func HandlerFromContext(c fiber.Ctx) *Handler {
|
|
handler, ok := c.Locals(handlerKey).(*Handler)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return handler
|
|
}
|
|
|
|
// getRawFromStorage returns the raw value from the storage for the given token
|
|
// returns nil if the token does not exist, is expired or is invalid
|
|
func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {
|
|
if cfg.Session != nil {
|
|
return sessionManager.getRaw(c, token, dummyValue)
|
|
}
|
|
return storageManager.getRaw(token)
|
|
}
|
|
|
|
// createOrExtendTokenInStorage creates or extends the token in the storage
|
|
func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
|
|
if cfg.Session != nil {
|
|
sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout)
|
|
} else {
|
|
storageManager.setRaw(token, dummyValue, cfg.IdleTimeout)
|
|
}
|
|
}
|
|
|
|
func deleteTokenFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
|
|
if cfg.Session != nil {
|
|
sessionManager.delRaw(c)
|
|
} else {
|
|
storageManager.delRaw(token)
|
|
}
|
|
}
|
|
|
|
// Update CSRF cookie
|
|
// if expireCookie is true, the cookie will expire immediately
|
|
func updateCSRFCookie(c fiber.Ctx, cfg Config, token string) {
|
|
setCSRFCookie(c, cfg, token, cfg.IdleTimeout)
|
|
}
|
|
|
|
func expireCSRFCookie(c fiber.Ctx, cfg Config) {
|
|
setCSRFCookie(c, cfg, "", -time.Hour)
|
|
}
|
|
|
|
func setCSRFCookie(c fiber.Ctx, cfg Config, token string, expiry time.Duration) {
|
|
cookie := &fiber.Cookie{
|
|
Name: cfg.CookieName,
|
|
Value: token,
|
|
Domain: cfg.CookieDomain,
|
|
Path: cfg.CookiePath,
|
|
Secure: cfg.CookieSecure,
|
|
HTTPOnly: cfg.CookieHTTPOnly,
|
|
SameSite: cfg.CookieSameSite,
|
|
SessionOnly: cfg.CookieSessionOnly,
|
|
Expires: time.Now().Add(expiry),
|
|
}
|
|
|
|
// Set the CSRF cookie to the response
|
|
c.Cookie(cookie)
|
|
}
|
|
|
|
// DeleteToken removes the token found in the context from the storage
|
|
// and expires the CSRF cookie
|
|
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
|
|
// Extract token from the client request cookie
|
|
cookieToken := c.Cookies(handler.config.CookieName)
|
|
if cookieToken == "" {
|
|
return handler.config.ErrorHandler(c, ErrTokenNotFound)
|
|
}
|
|
// Remove the token from storage
|
|
deleteTokenFromStorage(c, cookieToken, handler.config, handler.sessionManager, handler.storageManager)
|
|
// Expire the cookie
|
|
expireCSRFCookie(c, handler.config)
|
|
return nil
|
|
}
|
|
|
|
// isFromCookie checks if the extractor is set to ExtractFromCookie
|
|
func isFromCookie(extractor any) bool {
|
|
return reflect.ValueOf(extractor).Pointer() == reflect.ValueOf(FromCookie).Pointer()
|
|
}
|
|
|
|
// originMatchesHost checks that the origin header matches the host header
|
|
// returns an error if the origin header is not present or is invalid
|
|
// returns nil if the origin header is valid
|
|
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
|
|
origin := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
|
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
|
|
return errOriginNotFound
|
|
}
|
|
|
|
originURL, err := url.Parse(origin)
|
|
if err != nil {
|
|
return ErrOriginInvalid
|
|
}
|
|
|
|
if originURL.Scheme == c.Scheme() && originURL.Host == c.Host() {
|
|
return nil
|
|
}
|
|
|
|
for _, trustedOrigin := range trustedOrigins {
|
|
if origin == trustedOrigin {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
for _, trustedSubOrigin := range trustedSubOrigins {
|
|
if trustedSubOrigin.match(origin) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return ErrOriginNoMatch
|
|
}
|
|
|
|
// refererMatchesHost checks that the referer header matches the host header
|
|
// returns an error if the referer header is not present or is invalid
|
|
// returns nil if the referer header is valid
|
|
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
|
|
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
|
|
if referer == "" {
|
|
return ErrRefererNotFound
|
|
}
|
|
|
|
refererURL, err := url.Parse(referer)
|
|
if err != nil {
|
|
return ErrRefererInvalid
|
|
}
|
|
|
|
if refererURL.Scheme == c.Scheme() && refererURL.Host == c.Host() {
|
|
return nil
|
|
}
|
|
|
|
referer = refererURL.String()
|
|
|
|
for _, trustedOrigin := range trustedOrigins {
|
|
if referer == trustedOrigin {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
for _, trustedSubOrigin := range trustedSubOrigins {
|
|
if trustedSubOrigin.match(referer) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return ErrRefererNoMatch
|
|
}
|