diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index faea6752..4a56b388 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -49,29 +49,29 @@ func New(config ...Config) fiber.Handler { dummyVal := []byte{'+'} // Return new handler - return func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) (err error) { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } + var token string + // Action depends on the HTTP method switch c.Method() { case fiber.MethodGet: // Generate CSRF token if not exist // Declare empty token and try to get existing CSRF from cookie - token, key := "", c.Cookies(cfg.Cookie.Name) + token = c.Cookies(cfg.Cookie.Name) // Do we have an existing CSRF token? - if key != "" { - token = key - } else { + if token == "" { // Generate new CSRF token token = cfg.KeyGenerator() // Add token to Storage - if err := cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil { + if err = cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil { fmt.Println("[CSRF]", err.Error()) } } @@ -90,27 +90,17 @@ func New(config ...Config) fiber.Handler { // Set cookie to response c.Cookie(cookie) - - // Protect clients from caching the response by telling the browser - // a new header value is generated - c.Vary(fiber.HeaderCookie) - - // Store token in context if set - if cfg.ContextKey != "" { - c.Locals(cfg.ContextKey, token) - } - case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut: // Verify CSRF token // Extract token from client request i.e. header, query, param, form or cookie - csrf, err := extractor(c) + token, err = extractor(c) if err != nil { return fiber.ErrForbidden } // We have a problem extracting the csrf token from Storage - if _, err = cfg.Storage.Get(csrf); err != nil { + if _, err = cfg.Storage.Get(token); err != nil { // The token is invalid, let client generate a new one - if err = cfg.Storage.Delete(csrf); err != nil { + if err = cfg.Storage.Delete(token); err != nil { fmt.Println("[CSRF]", err.Error()) } // Expire cookie @@ -127,6 +117,15 @@ func New(config ...Config) fiber.Handler { } } + // Protect clients from caching the response by telling the browser + // a new header value is generated + c.Vary(fiber.HeaderCookie) + + // Store token in context if set + if cfg.ContextKey != "" { + c.Locals(cfg.ContextKey, token) + } + // Continue stack return c.Next() }