package csrf

import (
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/utils"
	"github.com/valyala/fasthttp"
)

// extractTokenFromSetCookie returns the value of the first cookie in a raw
// Set-Cookie header, i.e. the text between the first "=" and the first ";".
// It stops the test or benchmark with a descriptive message when no value can
// be found, instead of panicking on an out-of-range slice index.
func extractTokenFromSetCookie(tb testing.TB, setCookie []byte) string {
	tb.Helper()

	pair := string(setCookie)
	if end := strings.IndexByte(pair, ';'); end >= 0 {
		pair = pair[:end]
	}
	parts := strings.SplitN(pair, "=", 2)
	if len(parts) != 2 || parts[1] == "" {
		tb.Fatalf("no CSRF token found in Set-Cookie header %q", setCookie)
	}
	return parts[1]
}

// Test_CSRF checks, for each of GET, HEAD, OPTIONS and TRACE, that a POST is
// rejected without a token or with an unknown token, and accepted with the
// token issued in response to that method.
func Test_CSRF(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}

	for _, method := range methods {
		// Generate CSRF token
		ctx.Request.Header.SetMethod(method)
		h(ctx)

		// Without CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		h(ctx)
		utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

		// Empty/invalid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, "johndoe")
		h(ctx)
		utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

		// Valid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		h(ctx)
		token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, token)
		h(ctx)
		utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
	}
}

// Test_CSRF_Next checks that the middleware is skipped when Config.Next
// returns true.
//
// go test -run Test_CSRF_Next
func Test_CSRF_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ *fiber.Ctx) bool {
			return true
		},
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	utils.AssertEqual(t, nil, err)
	utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

// Test_CSRF_Invalid_KeyLookup checks that a malformed KeyLookup makes New
// panic with the expected message.
func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
	t.Parallel()
	defer func() {
		utils.AssertEqual(t, "[CSRF] KeyLookup must in the form of :", recover())
	}()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "I:am:invalid"}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
}

// Test_CSRF_From_Form checks token extraction from the "_csrf" form field.
func Test_CSRF_From_Form(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "form:_csrf"}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// Valid CSRF token; start from a clean request/response so nothing left
	// over from the GET above can influence the outcome.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.SetBodyString("_csrf=" + token)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
}

// Test_CSRF_From_Query checks token extraction from the "_csrf" query
// parameter.
func Test_CSRF_From_Query(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{KeyLookup: "query:_csrf"}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
	h(ctx)
	utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// Valid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/?_csrf=" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Param checks token extraction from the ":csrf" route
// parameter of a group.
func Test_CSRF_From_Param(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"}))

	csrfGroup.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/" + utils.UUID())
	h(ctx)
	utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/" + utils.UUID())
	h(ctx)
	token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// Valid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Cookie checks token extraction from the "csrf" request
// cookie.
func Test_CSRF_From_Cookie(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"}))

	csrfGroup.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
	h(ctx)
	utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// Valid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
	ctx.Request.SetRequestURI("/")
	h(ctx)
	utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}

// Test_CSRF_From_Custom checks that a user-supplied Config.Extractor is used
// to read the token; here it parses a "key=value" plain-text body.
func Test_CSRF_From_Custom(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	extractor := func(c *fiber.Ctx) (string, error) {
		body := string(c.Body())
		// Generate the correct extractor to get the token from the correct location
		selectors := strings.Split(body, "=")

		if len(selectors) != 2 || selectors[1] == "" {
			return "", errMissingParam
		}
		return selectors[1], nil
	}

	app.Use(New(Config{Extractor: extractor}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusForbidden, ctx.Response.StatusCode())

	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := extractTokenFromSetCookie(t, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// Valid CSRF token; start from a clean request/response so nothing left
	// over from the GET above can influence the outcome.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	ctx.Request.SetBodyString("_csrf=" + token)
	h(ctx)
	utils.AssertEqual(t, fiber.StatusOK, ctx.Response.StatusCode())
}

// Test_CSRF_ErrorHandler_InvalidToken checks that a custom ErrorHandler
// receives errTokenNotFound for an unknown token and controls the response.
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx *fiber.Ctx, err error) error {
		utils.AssertEqual(t, errTokenNotFound, err)
		return ctx.Status(419).Send([]byte("invalid CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// invalid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, "johndoe")
	h(ctx)
	utils.AssertEqual(t, 419, ctx.Response.StatusCode())
	utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
}

// Test_CSRF_ErrorHandler_EmptyToken checks that a custom ErrorHandler
// receives errMissingHeader when no token header is sent.
func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx *fiber.Ctx, err error) error {
		utils.AssertEqual(t, errMissingHeader, err)
		return ctx.Status(419).Send([]byte("empty CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// empty CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	utils.AssertEqual(t, 419, ctx.Response.StatusCode())
	utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
}

// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
// 	t.Parallel()
// 	app := fiber.New()
//
// 	app.Use(New())
// 	app.Get("/", func(c *fiber.Ctx) error {
// 		return c.SendStatus(fiber.StatusOK)
// 	})
// 	app.Get("/test", func(c *fiber.Ctx) error {
// 		return c.SendStatus(fiber.StatusOK)
// 	})
// 	app.Post("/", func(c *fiber.Ctx) error {
// 		return c.SendStatus(fiber.StatusOK)
// 	})
//
// 	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
// 	utils.AssertEqual(t, nil, err)
// 	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
//
// 	var token string
// 	for _, c := range resp.Cookies() {
// 		if c.Name != ConfigDefault.CookieName {
// 			continue
// 		}
// 		token = c.Value
// 		break
// 	}
//
// 	fmt.Println("token", token)
//
// 	getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
// 	getReq.Header.Set(HeaderName, token)
// 	resp, err = app.Test(getReq)
//
// 	getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
// 	getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// 	getReq.Header.Set(fiber.HeaderCacheControl, "no")
// 	getReq.Header.Set(HeaderName, token)
//
// 	resp, err = app.Test(getReq)
//
// 	getReq.Header.Set(fiber.HeaderAccept, "*/*")
// 	getReq.Header.Del(HeaderName)
// 	resp, err = app.Test(getReq)
//
// 	postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
// 	postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
// 	postReq.Header.Set(HeaderName, token)
// 	resp, err = app.Test(postReq)
// }

// Benchmark_Middleware_CSRF_Check measures the token validation path: every
// iteration sends a POST carrying a token issued by an earlier GET.
//
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})
	// The benchmarked request is a POST, so it needs its own route to reach a
	// handler once its token has been accepted.
	app.Post("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	h := app.Handler()

	// Generate CSRF token
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := extractTokenFromSetCookie(b, ctx.Response.Header.Peek(fiber.HeaderSetCookie))

	// fctx is the request that is timed. It must be a POST carrying the token;
	// an untouched RequestCtx would be a GET and only hit the token-issuing path.
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodPost)
	fctx.Request.Header.Set(HeaderName, token)

	b.ReportAllocs()
	b.ResetTimer()

	for n := 0; n < b.N; n++ {
		h(fctx)
	}

	utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}

// Benchmark_Middleware_CSRF_GenerateToken measures GET requests that carry no
// CSRF cookie, i.e. the path on which the middleware generates a token.
//
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	h := app.Handler()

	// Generate CSRF token
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)

	b.ReportAllocs()
	b.ResetTimer()

	for n := 0; n < b.N; n++ {
		h(fctx)
	}

	utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}