From 598b9ff12dcd60829571cae015cbefa88d2a01a2 Mon Sep 17 00:00:00 2001 From: JIeJaitt <498938874@qq.com> Date: Wed, 2 Apr 2025 17:26:00 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20Feature:=20Add=20CSRF=20support?= =?UTF-8?q?=20for=20proxy=20middleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/csrf/csrf.go | 31 +++++++++++---- middleware/csrf/csrf_test.go | 76 ++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index dedfe6bd..c2484d30 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -86,6 +86,10 @@ func New(config ...Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { + var ( + err error + ) + // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() @@ -171,17 +175,28 @@ func New(config ...Config) fiber.Handler { // Create or extend the token in the storage createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager) - // Update the CSRF cookie - updateCSRFCookie(c, cfg, token) - - // Tell the browser that a new header value is generated - c.Vary(fiber.HeaderCookie) - - // Store the token in the context + // Store the token in the context (Keep this BEFORE c.Next()) c.Locals(tokenKey, token) + // Execute the next middleware or handler in the stack. + err = c.Next() + + // Retrieve the final token from the context, if it was set. + finalToken, ok := c.Locals(tokenKey).(string) + // Check if the token exists and is not empty. + if ok && finalToken != "" { // Ensure token exists + // Update the CSRF cookie in the response with the final token. + updateCSRFCookie(c, cfg, finalToken) + + // Add the Vary: Cookie header to indicate that the response may differ + // based on the Cookie header, which is important for caching mechanisms. + // Tell the browser that a new header value is generated + c.Vary(fiber.HeaderCookie) + } + + // Return any error that occurred during the execution of the next handlers. // Continue stack - return c.Next() + return err } } diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 090082f4..b1bd603a 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1,12 +1,14 @@ package csrf import ( + "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/proxy" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" @@ -1564,3 +1566,77 @@ func Test_CSRF_FromContextMethods_Invalid(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } + +// Test_CSRF_With_Proxy_Middleware verifies that the CSRF cookie is set +// even when the request is handled by the proxy middleware, addressing +// issue https://github.com/gofiber/fiber/issues/3387 +func Test_CSRF_With_Proxy_Middleware(t *testing.T) { + t.Parallel() + + // 1. Create a target server that the proxy will forward to + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Hello from target")) + })) + defer targetServer.Close() + + // 2. Create the main Fiber app (proxy server) + app := fiber.New() + + // 3. Setup Session middleware (required by CSRF config in this test) + // sessionMiddleware, sessionStore := session.NewWithStore() + // app.Use(sessionMiddleware) + _, sessionStore := session.NewWithStore() + app.Use(session.New(session.Config{ + Store: sessionStore, + })) + + // 4. Setup CSRF middleware, configured to use the session store + // Ensure this runs AFTER the session middleware but BEFORE the proxy handler + app.Use(New(Config{ + Session: sessionStore, + })) + + // 5. Define a route that proxies requests to the target server + app.Get("/proxy", func(c fiber.Ctx) error { + // The proxy middleware will execute, potentially overwriting headers initially. + // The CSRF middleware should fix this after c.Next() (which proxy.Do calls internally) returns. + return proxy.Do(c, targetServer.URL) + }) + + // 6. Make a GET request to the proxy route + req := httptest.NewRequest(fiber.MethodGet, "/proxy", nil) + resp, err := app.Test(req) + + // 7. Assertions + require.NoError(t, err, "app.Test failed") + require.Equal(t, http.StatusOK, resp.StatusCode, "Expected status OK from target server via proxy") + + // --- CRUCIAL CHECK --- + // Verify that the CSRF cookie was set in the *final* response headers + // This confirms the CSRF middleware correctly set the cookie *after* the proxy logic. + foundCSRFCookie := false + csrfCookieName := ConfigDefault.CookieName // Use the default name or the one from config + for _, cookie := range resp.Cookies() { + if cookie.Name == csrfCookieName { + foundCSRFCookie = true + require.NotEmpty(t, cookie.Value, "CSRF cookie value should not be empty") + t.Logf("Found CSRF cookie: %s=%s", cookie.Name, cookie.Value) // Optional logging + break + } + } + require.True(t, foundCSRFCookie, "CSRF cookie '%s' not found in response headers after proxy", csrfCookieName) + + // Optional: Verify session cookie is also present + foundSessionCookie := false + sessionCookieName := "session_id" // Default name for session.NewWithStore() unless configured otherwise + for _, cookie := range resp.Cookies() { + if cookie.Name == sessionCookieName { + foundSessionCookie = true + require.NotEmpty(t, cookie.Value, "Session cookie value should not be empty") + t.Logf("Found Session cookie: %s=%s", cookie.Name, cookie.Value) // Optional logging + break + } + } + require.True(t, foundSessionCookie, "Session cookie '%s' not found in response headers after proxy", sessionCookieName) +}