diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index aab36685..53a8900e 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -178,6 +178,11 @@ func New(config ...Config) fiber.Handler { // Execute the next middleware or handler in the stack. err = c.Next() + // If the next handler returned an error, return it immediately. + // Do not proceed to update the CSRF cookie if the request failed downstream. + if err != nil { + return err + } // Retrieve the final token from the context, if it was set. finalToken, ok := c.Locals(tokenKey).(string) diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 3ce71115..78b0d7df 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1652,3 +1652,81 @@ func Test_CSRF_With_Proxy_Middleware(t *testing.T) { } require.True(t, foundSessionCookie, "Session cookie '%s' not found in response headers after proxy", sessionCookieName) } + +// Test_CSRF_Custom_CookieName_With_Proxy verifies that a custom CSRF cookie name +// is correctly set in the response even when the request goes through the proxy middleware. +func Test_CSRF_Custom_CookieName_With_Proxy(t *testing.T) { + t.Parallel() + + // 1. Define the custom cookie name to be tested + customCookieName := "my_special_csrf_cookie_for_proxy" + sessionCookieName := "session_id" + + // 2. Create a target server that the proxy will forward to + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Target server simply responds OK + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("Hello from target for custom cookie test")) + if err != nil { + t.Fatal(err) + } + })) + defer targetServer.Close() + + // 3. Create the main Fiber app (proxy server) + app := fiber.New() + + // 4. Setup Session middleware (often used with CSRF, ensures state handling) + // Use a specific store for this test to avoid conflicts if run in parallel with others + sessionStore := session.NewStore() + app.Use(session.New(session.Config{ + Store: sessionStore, + })) + + // 5. Setup CSRF middleware with the custom cookie name and session store + // This is the core part: configuring CSRF with the custom name. + app.Use(New(Config{ + CookieName: customCookieName, // Set the custom cookie name here + Session: sessionStore, // Link CSRF state to the session + })) + + // 6. Define a route that proxies requests to the target server + app.Get("/proxy-custom-cookie", func(c fiber.Ctx) error { + // The request will pass through session and CSRF middleware before hitting the proxy. + // The CSRF middleware should set the cookie *after* proxy.Do completes. + return proxy.Do(c, targetServer.URL) + }) + + // 7. Make a GET request to the proxy route + req := httptest.NewRequest(fiber.MethodGet, "/proxy-custom-cookie", nil) + resp, err := app.Test(req) + + // 8. Assertions + require.NoError(t, err, "app.Test failed for proxy with custom cookie name") + require.Equal(t, http.StatusOK, resp.StatusCode, "Expected status OK from target server via proxy") + + // --- CRUCIAL CHECK --- + // Verify that the CSRF cookie with the *custom name* was set in the final response headers. + foundCustomCSRFCookie := false + for _, cookie := range resp.Cookies() { + if cookie.Name == customCookieName { // Check specifically for the custom name + foundCustomCSRFCookie = true + require.NotEmpty(t, cookie.Value, "Custom CSRF cookie value should not be empty") + t.Logf("Found custom CSRF cookie via proxy: %s=%s", cookie.Name, cookie.Value) + break + } + } + require.True(t, foundCustomCSRFCookie, "Custom CSRF cookie '%s' not found in response headers after proxy", customCookieName) + + // Additionally: Verify session cookie is also present + foundSessionCookie := false + for _, cookie := range resp.Cookies() { + if cookie.Name == sessionCookieName { + foundSessionCookie = true + require.NotEmpty(t, cookie.Value, "Session cookie value should not be empty") + t.Logf("Found Session cookie via proxy: %s=%s", cookie.Name, cookie.Value) + break + } + } + require.True(t, foundSessionCookie, "Session cookie '%s' not found in response headers after proxy", sessionCookieName) +}