mirror of https://github.com/gofiber/fiber.git
463 lines
12 KiB
Go
463 lines
12 KiB
Go
package csrf
|
|
|
|
import (
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/fiber/v2/utils"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
func Test_CSRF(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
app.Use(New())
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
|
|
|
|
for _, method := range methods {
|
|
// Generate CSRF token
|
|
ctx.Request.Header.SetMethod(method)
|
|
h(ctx)
|
|
|
|
// Without CSRF cookie
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Empty/invalid CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(HeaderName, "johndoe")
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Valid CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(method)
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(HeaderName, token)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
}
|
|
}
|
|
|
|
// go test -run Test_CSRF_Next
|
|
func Test_CSRF_Next(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
app.Use(New(Config{
|
|
Next: func(_ *fiber.Ctx) bool {
|
|
return true
|
|
},
|
|
}))
|
|
|
|
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
|
utils.AssertEqual(t, nil, err)
|
|
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
|
}
|
|
|
|
func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
|
|
t.Parallel()
|
|
defer func() {
|
|
utils.AssertEqual(t, "[CSRF] KeyLookup must in the form of <source>:<key>", recover())
|
|
}()
|
|
app := fiber.New()
|
|
|
|
app.Use(New(Config{KeyLookup: "I:am:invalid"}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
}
|
|
|
|
func Test_CSRF_From_Form(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
app.Use(New(Config{KeyLookup: "form:_csrf"}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Invalid CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
|
|
ctx.Request.SetBodyString("_csrf=" + token)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
}
|
|
|
|
func Test_CSRF_From_Query(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
app.Use(New(Config{KeyLookup: "query:_csrf"}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Invalid CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
ctx.Request.SetRequestURI("/")
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.SetRequestURI("/?_csrf=" + token)
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
|
}
|
|
|
|
func Test_CSRF_From_Param(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"}))
|
|
|
|
csrfGroup.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Invalid CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.SetRequestURI("/" + utils.UUID())
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
ctx.Request.SetRequestURI("/" + utils.UUID())
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.SetRequestURI("/" + token)
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
|
}
|
|
|
|
func Test_CSRF_From_Cookie(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"}))
|
|
|
|
csrfGroup.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Invalid CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.SetRequestURI("/")
|
|
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
ctx.Request.SetRequestURI("/")
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
|
|
ctx.Request.SetRequestURI("/")
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
|
|
}
|
|
|
|
func Test_CSRF_From_Custom(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
extractor := func(c *fiber.Ctx) (string, error) {
|
|
body := string(c.Body())
|
|
// Generate the correct extractor to get the token from the correct location
|
|
selectors := strings.Split(body, "=")
|
|
|
|
if len(selectors) != 2 || selectors[1] == "" {
|
|
return "", errMissingParam
|
|
}
|
|
return selectors[1], nil
|
|
}
|
|
|
|
app.Use(New(Config{Extractor: extractor}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Invalid CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
|
|
ctx.Request.SetBodyString("_csrf=" + token)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
|
}
|
|
|
|
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
errHandler := func(ctx *fiber.Ctx, err error) error {
|
|
utils.AssertEqual(t, errTokenNotFound, err)
|
|
return ctx.Status(419).Send([]byte("invalid CSRF token"))
|
|
}
|
|
|
|
app.Use(New(Config{ErrorHandler: errHandler}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
|
|
// invalid CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(HeaderName, "johndoe")
|
|
h(ctx)
|
|
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
|
utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
|
|
}
|
|
|
|
func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
errHandler := func(ctx *fiber.Ctx, err error) error {
|
|
utils.AssertEqual(t, errMissingHeader, err)
|
|
return ctx.Status(419).Send([]byte("empty CSRF token"))
|
|
}
|
|
|
|
app.Use(New(Config{ErrorHandler: errHandler}))
|
|
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
|
|
// empty CSRF token
|
|
ctx.Request.Reset()
|
|
ctx.Response.Reset()
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
h(ctx)
|
|
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
|
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
|
|
}
|
|
|
|
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
|
|
// func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
|
// t.Parallel()
|
|
// app := fiber.New()
|
|
|
|
// app.Use(New())
|
|
// app.Get("/", func(c *fiber.Ctx) error {
|
|
// return c.SendStatus(fiber.StatusOK)
|
|
// })
|
|
// app.Get("/test", func(c *fiber.Ctx) error {
|
|
// return c.SendStatus(fiber.StatusOK)
|
|
// })
|
|
// app.Post("/", func(c *fiber.Ctx) error {
|
|
// return c.SendStatus(fiber.StatusOK)
|
|
// })
|
|
|
|
// resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
|
// utils.AssertEqual(t, nil, err)
|
|
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
// var token string
|
|
// for _, c := range resp.Cookies() {
|
|
// if c.Name != ConfigDefault.CookieName {
|
|
// continue
|
|
// }
|
|
// token = c.Value
|
|
// break
|
|
// }
|
|
|
|
// fmt.Println("token", token)
|
|
|
|
// getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
// getReq.Header.Set(HeaderName, token)
|
|
// resp, err = app.Test(getReq)
|
|
|
|
// getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
|
|
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
|
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
|
// getReq.Header.Set(HeaderName, token)
|
|
|
|
// resp, err = app.Test(getReq)
|
|
|
|
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
|
// getReq.Header.Del(HeaderName)
|
|
// resp, err = app.Test(getReq)
|
|
|
|
// postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
|
|
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
|
// postReq.Header.Set(HeaderName, token)
|
|
// resp, err = app.Test(postReq)
|
|
// }
|
|
|
|
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
|
|
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
|
|
app := fiber.New()
|
|
|
|
app.Use(New())
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusTeapot)
|
|
})
|
|
|
|
fctx := &fasthttp.RequestCtx{}
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
h(ctx)
|
|
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
|
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
|
|
|
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
|
ctx.Request.Header.Set(HeaderName, token)
|
|
|
|
b.ReportAllocs()
|
|
b.ResetTimer()
|
|
|
|
for n := 0; n < b.N; n++ {
|
|
h(fctx)
|
|
}
|
|
|
|
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
|
}
|
|
|
|
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
|
|
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
|
|
app := fiber.New()
|
|
|
|
app.Use(New())
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusTeapot)
|
|
})
|
|
|
|
fctx := &fasthttp.RequestCtx{}
|
|
h := app.Handler()
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Generate CSRF token
|
|
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
|
b.ReportAllocs()
|
|
b.ResetTimer()
|
|
|
|
for n := 0; n < b.N; n++ {
|
|
h(fctx)
|
|
}
|
|
|
|
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
|
}
|