package csrf

import (
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/middleware/session"
	"github.com/gofiber/utils/v2"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// Test_CSRF drives the default middleware through one cycle per token-issuing
// method (GET, HEAD, OPTIONS, TRACE): a POST without a token and a POST with an
// unknown token must both answer 403, and a POST that sends the issued token in
// both the header and the cookie must answer 200.
func Test_CSRF(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}

	for _, method := range methods {
		// Generate CSRF token
		ctx.Request.Header.SetMethod(method)
		h(ctx)

		// Without CSRF cookie
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())

		// Invalid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, "johndoe")
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())

		// Valid CSRF token: the token is the value of the first name=value
		// pair of the Set-Cookie response header.
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		h(ctx)
		token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
		token = strings.Split(strings.Split(token, ";")[0], "=")[1]

		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, token)
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		require.Equal(t, 200, ctx.Response.StatusCode())
	}
}

// Test_CSRF_WithSession repeats the Test_CSRF cycle with Config.Session set to
// a session store keyed by the "_session" cookie; every request carries the ID
// of a session that was created and saved up front.
func Test_CSRF_WithSession(t *testing.T) {
	t.Parallel()

	// session store
	store := session.NewStore(session.Config{
		KeyLookup: "cookie:_session",
	})

	// fiber instance
	app := fiber.New()

	// fiber context
	ctx := &fasthttp.RequestCtx{}
	defer app.ReleaseCtx(app.AcquireCtx(ctx))

	// get session
	sess, err := store.Get(app.AcquireCtx(ctx))
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// remember the ID of the fresh session so it can be sent as a cookie
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())

	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)

	// middleware config
	config := Config{
		Session: store,
	}

	// middleware
	app.Use(New(config))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()

	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}

	for _, method := range methods {
		// Generate CSRF token (always via GET here; the loop method is only
		// used for the "Valid CSRF token" request below)
		ctx.Request.Header.SetMethod(fiber.MethodGet)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)

		// Without CSRF cookie
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())

		// Empty/invalid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, "johndoe")
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())

		// Valid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
		// Pick the attribute whose name equals the CSRF cookie name.
		for _, header := range strings.Split(token, ";") {
			if strings.Split(utils.Trim(header, ' '), "=")[0] == ConfigDefault.CookieName {
				token = strings.Split(header, "=")[1]
				break
			}
		}

		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, token)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		require.Equal(t, 200, ctx.Response.StatusCode())
	}
}

// Test_CSRF_WithSession_Middleware registers the session middleware ahead of
// the CSRF middleware (sharing one store). A GET stores "hello"="world" in the
// session; the following POST, sent with the CSRF token and session ID read
// from the GET response, must answer 200, which the POST handler only does
// when it reads "world" back from the session.
//
// go test -run Test_CSRF_WithSession_Middleware
func Test_CSRF_WithSession_Middleware(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// session mw
	smh, sstore := session.NewWithStore()

	// csrf mw
	cmh := New(Config{
		Session: sstore,
	})

	app.Use(smh)

	app.Use(cmh)

	app.Get("/", func(c fiber.Ctx) error {
		sess := session.FromContext(c)
		sess.Set("hello", "world")
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/", func(c fiber.Ctx) error {
		sess := session.FromContext(c)
		if sess.Get("hello") != "world" {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token and session_id
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	csrfTokenParts := strings.Split(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), ";")
	require.Greater(t, len(csrfTokenParts), 2)
	csrfToken := strings.Split(csrfTokenParts[0], "=")[1]
	require.NotEmpty(t, csrfToken)
	sessionID := strings.Split(csrfTokenParts[1], "=")[1]
	require.NotEmpty(t, sessionID)

	// Use the CSRF token and session_id
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, csrfToken)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken)
	ctx.Request.Header.SetCookie("session_id", sessionID)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
}

// Test_CSRF_ExpiredToken configures IdleTimeout to one second, uses a freshly
// issued token successfully (200), sleeps 1.25s and expects the same token to
// be answered with 403.
//
// go test -run Test_CSRF_ExpiredToken
func Test_CSRF_ExpiredToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		IdleTimeout: 1 * time.Second,
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Wait for the token to expire
	time.Sleep(1250 * time.Millisecond)

	// Expired CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_ExpiredToken_WithSession is the session-store variant of
// Test_CSRF_ExpiredToken: IdleTimeout is one second, the token works once
// (200), and after a 1.1s sleep the same token must be answered with 403.
//
// go test -run Test_CSRF_ExpiredToken_WithSession
func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
	t.Parallel()

	// session store
	store := session.NewStore(session.Config{
		KeyLookup: "cookie:_session",
	})

	// fiber instance
	app := fiber.New()

	// fiber context
	ctx := &fasthttp.RequestCtx{}
	defer app.ReleaseCtx(app.AcquireCtx(ctx))

	// get session
	sess, err := store.Get(app.AcquireCtx(ctx))
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// get session id
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())

	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)

	// middleware config
	config := Config{
		Session:     store,
		IdleTimeout: 1 * time.Second,
	}

	// middleware
	app.Use(New(config))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	// Pick the attribute whose name equals the CSRF cookie name.
	for _, header := range strings.Split(token, ";") {
		if strings.Split(utils.Trim(header, ' '), "=")[0] == ConfigDefault.CookieName {
			token = strings.Split(header, "=")[1]
			break
		}
	}

	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Wait for the token to expire
	time.Sleep(1*time.Second + 100*time.Millisecond)

	// Expired CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_MultiUseToken checks that, without SingleUseToken, the Set-Cookie
// header returned by a successful POST carries the same token value that the
// request presented.
//
// go test -run Test_CSRF_MultiUseToken
func Test_CSRF_MultiUseToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		KeyLookup: "header:X-Csrf-Token",
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", "johndoe")
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1]
	require.Equal(t, 200, ctx.Response.StatusCode())

	// The token must be unchanged after use
	require.Equal(t, token, newToken)
}

// Test_CSRF_SingleUseToken enables SingleUseToken and checks that a successful
// POST returns a different token in Set-Cookie, and that replaying the original
// token is answered with 403.
//
// go test -run Test_CSRF_SingleUseToken
func Test_CSRF_SingleUseToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		SingleUseToken: true,
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1]
	require.NotEqual(t, token, newToken, "new token should not be the same as the old token")

	// Use the CSRF token again
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_Next configures a Next function that always returns true and
// expects a GET against an app with no routes to answer 404.
//
// go test -run Test_CSRF_Next
func Test_CSRF_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}

// Test_CSRF_Invalid_KeyLookup expects a panic with a fixed message when the
// middleware is used with KeyLookup "I:am:invalid".
func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
	t.Parallel()
	defer func() {
		// NOTE(review): the expected message ends in "of :" — confirm it
		// matches the panic string raised by the middleware verbatim.
		require.Equal(t, "[CSRF] KeyLookup must in the form of :", recover())
	}()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "I:am:invalid"}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
}

// Test_CSRF_From_Form uses KeyLookup "form:_csrf": a form POST without the
// field answers 403, and a form POST whose body is "_csrf=<token>" (with the
// matching cookie) answers 200.
func Test_CSRF_From_Form(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "form:_csrf"}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
}

// Test_CSRF_From_Query uses KeyLookup "query:_csrf": a POST with a random
// UUID in the query string answers 403, and a POST to "/?_csrf=<token>" with
// the matching cookie answers 200 with body "OK".
func Test_CSRF_From_Query(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "query:_csrf"}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/?_csrf=" + utils.UUIDv4())
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/?_csrf=" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Param mounts the middleware on the "/:csrf" group with
// KeyLookup "param:csrf": a POST to a random UUID path answers 403, and a POST
// to "/<token>" with the matching cookie answers 200 with body "OK".
func Test_CSRF_From_Param(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"}))

	csrfGroup.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/" + utils.UUIDv4())
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/" + utils.UUIDv4())
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Cookie mounts the middleware on the "/" group with KeyLookup
// "cookie:csrf": a POST whose "csrf" cookie is a random UUID answers 403, and
// a POST whose "csrf" cookie is the issued token answers 200 with body "OK".
func Test_CSRF_From_Cookie(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"}))

	csrfGroup.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUIDv4()+";")
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
	ctx.Request.SetRequestURI("/")
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Custom supplies a custom Extractor that takes the token from
// a "<key>=<value>" request body: a POST with an empty body answers 403, and a
// POST with body "_csrf=<token>" plus the matching cookie answers 200.
func Test_CSRF_From_Custom(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// The extractor returns the part after the single "=" in the body and
	// ErrMissingParam when the body does not have exactly that shape.
	extractor := func(c fiber.Ctx) (string, error) {
		body := string(c.Body())
		// Split the body into key and value to locate the token
		selectors := strings.Split(body, "=")

		if len(selectors) != 2 || selectors[1] == "" {
			return "", ErrMissingParam
		}
		return selectors[1], nil
	}

	app.Use(New(Config{Extractor: extractor}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body()))
}

// Test_CSRF_Origin exercises Origin-header checks with CookieSecure enabled:
// a matching origin (with port), the literal "null" origin and a matching
// origin behind X-Forwarded-* headers expect 200; a wrong port, missing
// X-Forwarded-* headers and a foreign origin expect 403.
func Test_CSRF_Origin(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{CookieSecure: true}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test Correct Origin with port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com:8080")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:8080")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Origin with wrong port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:3000")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Test Origin with null
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "null")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Correct Origin with ReverseProxy
	// NOTE(review): the URI host "10.0.1.42.com:8080" differs from the Host
	// header "10.0.1.42:8080" set below — confirm whether this is intended.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("10.0.1.42.com:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("10.0.1.42:8080")
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`)
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Correct Origin with ReverseProxy Missing X-Forwarded-* Headers
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("10.0.1.42:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("10.0.1.42:8080")
	ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "http") // We need to set this header to make sure c.Protocol() returns http
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Test Wrong Origin
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://csrf.example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_TrustedOrigins configures exact and "*." wildcard TrustedOrigins
// and checks both the Origin and the Referer header: listed origins and
// (nested) subdomains of the wildcard entries expect 200, while
// "evildomain-1.com" expects 403.
func Test_CSRF_TrustedOrigins(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		CookieSecure: true,
		TrustedOrigins: []string{
			"http://safe.example.com",
			"https://safe.example.com",
			"http://*.domain-1.com",
			"https://*.domain-1.com",
		},
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test Trusted Origin
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://safe.example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Origin Subdomain
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("domain-1.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://safe.domain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Origin deeply nested subdomain
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("a.b.c.domain-1.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("a.b.c.domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "https://a.b.c.domain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Origin Invalid
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("domain-1.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://evildomain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Test Trusted Referer
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://safe.example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Referer Wildcard
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("domain-1.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://safe.domain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Referer deeply nested subdomain
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("a.b.c.domain-1.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("a.b.c.domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://a.b.c.domain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Trusted Referer Invalid
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("api.domain-1.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("api.domain-1.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://evildomain-1.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_TrustedOrigins_InvalidOrigins expects New to panic for each
// malformed TrustedOrigins entry in the table (missing scheme, bare or
// misplaced wildcards, non-HTTP schemes, wildcard port).
func Test_CSRF_TrustedOrigins_InvalidOrigins(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		origin string
	}{
		{name: "No Scheme", origin: "localhost"},
		{name: "Wildcard", origin: "https://*"},
		{name: "Wildcard domain", origin: "https://*example.com"},
		{name: "File Scheme", origin: "file://example.com"},
		{name: "FTP Scheme", origin: "ftp://example.com"},
		{name: "Port Wildcard", origin: "http://example.com:*"},
		{name: "Multiple Wildcards", origin: "https://*.*.com"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Copy the table value before t.Parallel pauses this subtest.
			origin := tt.origin
			t.Parallel()
			require.Panics(t, func() {
				app := fiber.New()
				app.Use(New(Config{
					CookieSecure:   true,
					TrustedOrigins: []string{origin},
				}))
			}, "Expected panic")
		})
	}
}

// Test_CSRF_Referer exercises Referer-header checks for requests marked as
// https with CookieSecure enabled: a matching referer (with port, behind
// X-Forwarded-* headers, or with a path) expects 200; missing X-Forwarded-*
// headers and a foreign referer expect 403.
func Test_CSRF_Referer(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{CookieSecure: true}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test Correct Referer with port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com:8443")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("example.com:8443")
	ctx.Request.Header.Set(fiber.HeaderReferer, ctx.Request.URI().String())
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Correct Referer with ReverseProxy
	// NOTE(review): the URI host "10.0.1.42.com:8443" differs from the Host
	// header "10.0.1.42:8443" set below — confirm whether this is intended.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("10.0.1.42.com:8443")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("10.0.1.42:8443")
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`)
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Correct Referer with ReverseProxy Missing X-Forwarded-* Headers
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("10.0.1.42:8443")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("10.0.1.42:8443")
	ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "https") // We need to set this header to make sure c.Protocol() returns https
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Test Correct Referer with path
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com/action/items?gogogo=true")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test Wrong Referer
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://csrf.example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_DeleteToken issues a token, deletes it through the handler
// returned by HandlerFromContext (when one is available), and expects a later
// POST with that token to be answered with 403. Before any token exists, a
// DeleteToken call on a request without cookies must report ErrTokenNotFound.
func Test_CSRF_DeleteToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	config := ConfigDefault

	app.Use(New(config))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// DeleteToken after token generation and remove the cookie
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.Set(HeaderName, "")
	handler := HandlerFromContext(app.AcquireCtx(ctx))
	if handler != nil {
		ctx.Request.Header.DelAllCookies()
		err := handler.DeleteToken(app.AcquireCtx(ctx))
		require.ErrorIs(t, err, ErrTokenNotFound)
	}
	h(ctx)

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Delete the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	handler = HandlerFromContext(app.AcquireCtx(ctx))
	if handler != nil {
		if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
			t.Fatal(err)
		}
	}
	h(ctx)

	// The deleted token must no longer be accepted
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_DeleteToken_WithSession is the session-store variant of
// Test_CSRF_DeleteToken: after the token is deleted through the handler (when
// one is available), a POST carrying the token and the session cookie must be
// answered with 403.
func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
	t.Parallel()

	// session store
	store := session.NewStore(session.Config{
		KeyLookup: "cookie:_session",
	})

	// fiber instance
	app := fiber.New()

	// fiber context
	ctx := &fasthttp.RequestCtx{}

	// get session
	sess, err := store.Get(app.AcquireCtx(ctx))
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// remember the ID of the fresh session so it can be sent as a cookie
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())

	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)

	// middleware config
	config := Config{
		Session: store,
	}

	// middleware
	app.Use(New(config))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Delete the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	handler := HandlerFromContext(app.AcquireCtx(ctx))
	if handler != nil {
		if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
			t.Fatal(err)
		}
	}
	h(ctx)

	// The deleted token must no longer be accepted
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_ErrorHandler_InvalidToken installs a custom ErrorHandler and
// expects it to receive ErrTokenInvalid for a POST with an unknown token; the
// handler's 419 status and body must reach the response.
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrTokenInvalid, err)
		return ctx.Status(419).Send([]byte("invalid CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// invalid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, "johndoe")
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, "invalid CSRF token", string(ctx.Response.Body()))
}

// Test_CSRF_ErrorHandler_EmptyToken installs a custom ErrorHandler and expects
// it to receive ErrMissingHeader for a POST without a token header; the
// handler's 419 status and body must reach the response.
func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrMissingHeader, err)
		return ctx.Status(419).Send([]byte("empty CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// empty CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, "empty CSRF token", string(ctx.Response.Body()))
}

// Test_CSRF_ErrorHandler_MissingReferer installs a custom ErrorHandler with
// CookieSecure enabled and expects it to receive ErrRefererNotFound for a POST
// marked as https that carries a valid token but no Referer header; the
// handler's 419 status must reach the response.
func Test_CSRF_ErrorHandler_MissingReferer(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrRefererNotFound, err)
		return ctx.Status(419).Send([]byte("empty CSRF token"))
	}

	app.Use(New(Config{
		CookieSecure: true,
		ErrorHandler: errHandler,
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Valid token, but no Referer header
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
}

// Test_CSRF_Cookie_Injection_Exploit sends a GET with a client-chosen
// "csrf_=pwned" cookie, then a POST that presents the token from the GET
// response in the header together with the "pwned" cookie, and expects 403.
func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Inject CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf_=pwned;")
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Exploit CSRF token we just injected
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf_=pwned;")
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode(), "CSRF exploit successful")
}

// TODO: use this test case and make
the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase func Test_CSRF_UnsafeHeaderValue(t *testing.T) { t.SkipNow() t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) app.Get("/test", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) var token string for _, c := range resp.Cookies() { if c.Name != ConfigDefault.CookieName { continue } token = c.Value break } t.Log("token", token) getReq := httptest.NewRequest(fiber.MethodGet, "/", nil) getReq.Header.Set(HeaderName, token) resp, err = app.Test(getReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil) getReq.Header.Set("X-Requested-With", "XMLHttpRequest") getReq.Header.Set(fiber.HeaderCacheControl, "no") getReq.Header.Set(HeaderName, token) resp, err = app.Test(getReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) getReq.Header.Set(fiber.HeaderAccept, "*/*") getReq.Header.Del(HeaderName) resp, err = app.Test(getReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) postReq := httptest.NewRequest(fiber.MethodPost, "/", nil) postReq.Header.Set("X-Requested-With", "XMLHttpRequest") postReq.Header.Set(HeaderName, token) resp, err = app.Test(postReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4 func Benchmark_Middleware_CSRF_Check(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) app.Post("/", func(c fiber.Ctx) 
error { return c.SendStatus(fiber.StatusTeapot) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Test Correct Referer POST ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { h(ctx) } require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode()) } // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4 func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { h(ctx) } // Ensure the GET request returns a 418 status code require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode()) } func Test_CSRF_InvalidURLHeaders(t *testing.T) { t.Parallel() app := fiber.New() errHandler := func(ctx fiber.Ctx, err error) error { return ctx.Status(419).Send([]byte(err.Error())) } app.Use(New(Config{ErrorHandler: errHandler})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) 
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // invalid Origin ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://[::1]:%38%30/Invalid Origin") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 419, ctx.Response.StatusCode()) require.Equal(t, ErrOriginInvalid.Error(), string(ctx.Response.Body())) // invalid Referer ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "http://[::1]:%38%30/Invalid Referer") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 419, ctx.Response.StatusCode()) require.Equal(t, ErrRefererInvalid.Error(), string(ctx.Response.Body())) } func Test_CSRF_TokenFromContext(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { token := TokenFromContext(c) require.NotEmpty(t, token) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } func Test_CSRF_FromContextMethods(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { token := TokenFromContext(c) require.NotEmpty(t, token) handler := 
HandlerFromContext(c) require.NotNil(t, handler) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } func Test_CSRF_FromContextMethods_Invalid(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/", func(c fiber.Ctx) error { token := TokenFromContext(c) require.Empty(t, token) handler := HandlerFromContext(c) require.Nil(t, handler) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) }