🩹 fix test cases

pull/958/head
Fenny 2020-10-24 01:00:09 +02:00
parent 3f7b80e9a6
commit 9f2c0691b0
3 changed files with 96 additions and 31 deletions

View File

@ -34,7 +34,7 @@ app.Use(csrf.New(csrf.Config{
Cookie: &fiber.Cookie{
Name: "_csrf",
},
CookieExpires: 24 * time.Hour,
Expiration: 24 * time.Hour,
}))
```
@ -63,10 +63,10 @@ type Config struct {
// Optional.
Cookie *fiber.Cookie
// CookieExpires is the duration before the cookie will expire
// Expiration is the duration before csrf token will expire
//
// Optional. Default: 24 * time.Hour
CookieExpires time.Duration
Expiration time.Duration
// Context key to store generated CSRF token into context.
//
@ -83,11 +83,8 @@ var ConfigDefault = Config{
ContextKey: "csrf",
Cookie: &fiber.Cookie{
Name: "_csrf",
Domain: "",
Path: "",
Secure: false,
HTTPOnly: false,
SameSite: "Strict",
},
CookieExpires: 24 * time.Hour,
Expiration: 24 * time.Hour,
}
```

View File

@ -84,7 +84,7 @@ func New(config ...Config) fiber.Handler {
cfg.ContextKey = ConfigDefault.ContextKey
}
if cfg.CookieExpires != 0 {
fmt.Println("CookieExpires is deprecated, please use Expiration")
fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration")
cfg.CookieExpires = ConfigDefault.Expiration
}
if cfg.Expiration == 0 {
@ -183,7 +183,6 @@ func New(config ...Config) fiber.Handler {
db.RLock()
t, ok := db.tokens[csrf]
db.RUnlock()
// Check if token exist or expired
if !ok || time.Now().Unix() >= t {
return fiber.ErrForbidden
@ -204,7 +203,6 @@ func New(config ...Config) fiber.Handler {
// Set cookie to response
c.Cookie(cookie)
// Store token in context
c.Locals(cfg.ContextKey, token)

View File

@ -25,7 +25,8 @@ func Test_CSRF(t *testing.T) {
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
h(ctx)
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), "_csrf"))
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Without CSRF cookie
ctx.Request.Reset()
@ -43,11 +44,14 @@ func Test_CSRF(t *testing.T) {
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Valid CSRF token
token := utils.UUID()
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
h(ctx)
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "_csrf="+token)
ctx.Request.Header.Set("X-CSRF-Token", token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
@ -97,17 +101,21 @@ func Test_CSRF_From_Form(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Valid CSRF token
token := utils.UUID()
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "_csrf="+token)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "_csrf="+token)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
ctx.Request.SetBodyString("_csrf=" + token)
h(ctx)
@ -126,19 +134,28 @@ func Test_CSRF_From_Query(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Valid CSRF token
token := utils.UUID()
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "_csrf="+token)
ctx.Request.SetRequestURI("/?_csrf=" + token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
ctx.Request.SetRequestURI("/")
ctx.Response.Reset()
ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
utils.AssertEqual(t, "Forbidden", string(ctx.Response.Body()))
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.SetRequestURI("/?_csrf=" + token)
ctx.Request.Header.SetMethod("POST")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}
func Test_CSRF_From_Param(t *testing.T) {
@ -153,11 +170,64 @@ func Test_CSRF_From_Param(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Valid CSRF token
token := utils.UUID()
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "_csrf="+token)
ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/" + utils.UUID())
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.SetRequestURI("/" + token)
ctx.Request.Header.SetMethod("POST")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}
func Test_CSRF_From_Cookie(t *testing.T) {
app := fiber.New()
csrfGroup := app.Group("/", New(Config{TokenLookup: "cookie:csrf"}))
csrfGroup.Post("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
ctx.Request.SetRequestURI("/")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}