🔥 Feature: Add CSRF support for proxy middleware

pull/3390/head
JIeJaitt 2025-04-02 17:26:00 +08:00
parent b0bc32b534
commit 598b9ff12d
2 changed files with 99 additions and 8 deletions

View File

@ -86,6 +86,10 @@ func New(config ...Config) fiber.Handler {
// Return new handler
return func(c fiber.Ctx) error {
var (
err error
)
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
@ -171,17 +175,28 @@ func New(config ...Config) fiber.Handler {
// Create or extend the token in the storage
createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager)
// Update the CSRF cookie
updateCSRFCookie(c, cfg, token)
// Tell the browser that a new header value is generated
c.Vary(fiber.HeaderCookie)
// Store the token in the context
// Store the token in the context (Keep this BEFORE c.Next())
c.Locals(tokenKey, token)
// Execute the next middleware or handler in the stack.
err = c.Next()
// Retrieve the final token from the context, if it was set.
finalToken, ok := c.Locals(tokenKey).(string)
// Check if the token exists and is not empty.
if ok && finalToken != "" { // Ensure token exists
// Update the CSRF cookie in the response with the final token.
updateCSRFCookie(c, cfg, finalToken)
// Add the Vary: Cookie header to indicate that the response may differ
// based on the Cookie header, which is important for caching mechanisms.
// Tell the browser that a new header value is generated
c.Vary(fiber.HeaderCookie)
}
// Return any error that occurred during the execution of the next handlers.
// Continue stack
return c.Next()
return err
}
}

View File

@ -1,12 +1,14 @@
package csrf
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/proxy"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
@ -1564,3 +1566,77 @@ func Test_CSRF_FromContextMethods_Invalid(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// Test_CSRF_With_Proxy_Middleware verifies that the CSRF cookie is set
// even when the request is handled by the proxy middleware, addressing
// issue https://github.com/gofiber/fiber/issues/3387
func Test_CSRF_With_Proxy_Middleware(t *testing.T) {
t.Parallel()
// 1. Create a target server that the proxy will forward to
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Hello from target"))
}))
defer targetServer.Close()
// 2. Create the main Fiber app (proxy server)
app := fiber.New()
// 3. Setup Session middleware (required by CSRF config in this test)
// sessionMiddleware, sessionStore := session.NewWithStore()
// app.Use(sessionMiddleware)
_, sessionStore := session.NewWithStore()
app.Use(session.New(session.Config{
Store: sessionStore,
}))
// 4. Setup CSRF middleware, configured to use the session store
// Ensure this runs AFTER the session middleware but BEFORE the proxy handler
app.Use(New(Config{
Session: sessionStore,
}))
// 5. Define a route that proxies requests to the target server
app.Get("/proxy", func(c fiber.Ctx) error {
// The proxy middleware will execute, potentially overwriting headers initially.
// The CSRF middleware should fix this after c.Next() (which proxy.Do calls internally) returns.
return proxy.Do(c, targetServer.URL)
})
// 6. Make a GET request to the proxy route
req := httptest.NewRequest(fiber.MethodGet, "/proxy", nil)
resp, err := app.Test(req)
// 7. Assertions
require.NoError(t, err, "app.Test failed")
require.Equal(t, http.StatusOK, resp.StatusCode, "Expected status OK from target server via proxy")
// --- CRUCIAL CHECK ---
// Verify that the CSRF cookie was set in the *final* response headers
// This confirms the CSRF middleware correctly set the cookie *after* the proxy logic.
foundCSRFCookie := false
csrfCookieName := ConfigDefault.CookieName // Use the default name or the one from config
for _, cookie := range resp.Cookies() {
if cookie.Name == csrfCookieName {
foundCSRFCookie = true
require.NotEmpty(t, cookie.Value, "CSRF cookie value should not be empty")
t.Logf("Found CSRF cookie: %s=%s", cookie.Name, cookie.Value) // Optional logging
break
}
}
require.True(t, foundCSRFCookie, "CSRF cookie '%s' not found in response headers after proxy", csrfCookieName)
// Optional: Verify session cookie is also present
foundSessionCookie := false
sessionCookieName := "session_id" // Default name for session.NewWithStore() unless configured otherwise
for _, cookie := range resp.Cookies() {
if cookie.Name == sessionCookieName {
foundSessionCookie = true
require.NotEmpty(t, cookie.Value, "Session cookie value should not be empty")
t.Logf("Found Session cookie: %s=%s", cookie.Name, cookie.Value) // Optional logging
break
}
}
require.True(t, foundSessionCookie, "Session cookie '%s' not found in response headers after proxy", sessionCookieName)
}