fiber/middleware/session/middleware_test.go
Jason McNeil 670fbd5e45
feat: Add support for Keys() in session middleware (#3517)
*  feat: Update session middleware: add Keys method and update docs to match the `any` key type

*  test: Add /keys endpoint to set and retrieve multiple session keys

*  test: Refactor /keys endpoint to trim whitespace and improve key handling in session tests
2025-06-13 11:40:17 +02:00

519 lines
15 KiB
Go

package session
import (
"fmt"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test_Session_Middleware drives one session through the middleware lifecycle
// (get, set, delete, reset, destroy, fresh, keys). It issues sequential
// requests against the app handler and carries the session cookie from each
// response into the next request.
func Test_Session_Middleware(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())

	app.Get("/get", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		value, ok := sess.Get("key").(string)
		if !ok {
			return c.Status(fiber.StatusNotFound).SendString("key not found")
		}
		return c.SendString("value=" + value)
	})

	app.Post("/set", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// get a value from the body
		value := c.FormValue("value")
		sess.Set("key", value)
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/delete", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		sess.Delete("key")
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/reset", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		if err := sess.Reset(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/destroy", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		if err := sess.Destroy(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/fresh", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Reset the session to make it fresh
		if err := sess.Reset(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		if sess.Fresh() {
			return c.SendStatus(fiber.StatusOK)
		}
		return c.SendStatus(fiber.StatusInternalServerError)
	})

	app.Post("/keys", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// The "keys" form field is a comma-separated list; blank entries are skipped.
		value := c.FormValue("keys")
		for _, rawKey := range strings.Split(value, ",") {
			key := strings.TrimSpace(rawKey)
			if key == "" {
				continue
			}
			// Set each key in the session
			sess.Set(key, "value_"+key)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Get("/keys", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		keys := sess.Keys()
		if len(keys) == 0 {
			return c.SendStatus(fiber.StatusNotFound)
		}
		// Keys may be of any type, so convert to string for display.
		strKeys := make([]string, 0, len(keys))
		for _, key := range keys {
			strKeys = append(strKeys, fmt.Sprint(key))
		}
		return c.SendString("keys=" + strings.Join(strKeys, ","))
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// send clears the shared ctx and issues a single request through h.
	// A non-empty token is sent as the session_id cookie.
	// A non-empty formBody is sent URL-encoded.
	send := func(method, uri, token, formBody string) {
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		ctx.Request.SetRequestURI(uri)
		if formBody != "" {
			ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
			ctx.Request.SetBodyString(formBody)
		}
		if token != "" {
			ctx.Request.Header.SetCookie("session_id", token)
		}
		h(ctx)
	}

	// tokenFromResponse returns the cookie value from the last response's
	// Set-Cookie header ("name=value; attrs..."). The test fails if the header
	// is absent or malformed.
	tokenFromResponse := func() string {
		t.Helper()
		header := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
		require.NotEmpty(t, header, "Expected Set-Cookie header to be present")
		pair := strings.SplitN(strings.SplitN(header, ";", 2)[0], "=", 2)
		require.Len(t, pair, 2, "Expected Set-Cookie header to contain a token")
		return pair[1]
	}

	// GET /get on a new session: the key is absent, but a session cookie is issued.
	send(fiber.MethodGet, "/get", "", "")
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	token := tokenFromResponse()
	require.Equal(t, "key not found", string(ctx.Response.Body()))

	// POST /set stores "hello" under "key".
	send(fiber.MethodPost, "/set", token, "value=hello")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// GET /get returns the stored value.
	send(fiber.MethodGet, "/get", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value=hello", string(ctx.Response.Body()))

	// POST /delete removes the key...
	send(fiber.MethodPost, "/delete", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// ...so GET /get reports it missing again.
	send(fiber.MethodGet, "/get", token, "")
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	require.Equal(t, "key not found", string(ctx.Response.Body()))

	// POST /reset must hand out a different session token.
	send(fiber.MethodPost, "/reset", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	newToken := tokenFromResponse()
	require.NotEqual(t, token, newToken)
	token = newToken

	// POST /destroy must expire the session cookie.
	send(fiber.MethodPost, "/destroy", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), "max-age=0")

	// NOTE(review): kept from the original sequence. The session was just
	// destroyed, so this pause is likely unnecessary; confirm before removing.
	time.Sleep(1 * time.Second)

	// GET /get with the destroyed session's cookie must be issued a new token.
	send(fiber.MethodGet, "/get", token, "")
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	newToken = tokenFromResponse()
	require.NotEqual(t, token, newToken)
	token = newToken

	// POST /fresh resets the session and requires Fresh() to report true afterwards.
	send(fiber.MethodPost, "/fresh", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	newToken = tokenFromResponse()
	require.NotEqual(t, token, newToken)
	token = newToken

	// POST /keys stores key1 and key2...
	send(fiber.MethodPost, "/keys", token, "keys=key1,key2")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// ...and GET /keys lists them; order is not guaranteed, so sort before comparing.
	send(fiber.MethodGet, "/keys", token, "")
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	body := string(ctx.Response.Body())
	require.True(t, strings.HasPrefix(body, "keys="))
	gotKeys := strings.Split(strings.TrimPrefix(body, "keys="), ",")
	require.Len(t, gotKeys, 2, "Expected two keys in the session")
	sort.Strings(gotKeys)
	require.Equal(t, []string{"key1", "key2"}, gotKeys)
}
// Test_Session_NewWithStore verifies two things about the handler returned by
// NewWithStore. It issues a session cookie on the first request, and it
// resumes the same session ID when that cookie is sent back.
func Test_Session_NewWithStore(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	// Exercise the constructor this test is named after; it previously used New().
	handler, store := NewWithStore()
	require.NotNil(t, store)
	app.Use(handler)

	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendString("value=" + sess.ID())
	})

	app.Post("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		c.Cookie(&fiber.Cookie{
			Name:  "session_id",
			Value: sess.ID(),
		})
		return nil
	})

	h := app.Handler()

	// Test GET request without cookie
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Get session cookie ("name=value; attrs...")
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "value="+token, string(ctx.Response.Body()))

	// Test GET request with cookie: the same session ID must be resumed.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value="+token, string(ctx.Response.Body()))
}
// Test_Session_FromSession verifies that FromContext returns nil for a context
// that never passed through the session middleware.
func Test_Session_FromSession(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	// Hand the pooled context back once the assertion is done.
	defer app.ReleaseCtx(c)

	sess := FromContext(c)
	require.Nil(t, sess)
}
// Test_Session_WithConfig exercises a custom Config: cookie key lookup
// "session_id_test", a KeyGenerator that always returns "test", and a 1s
// IdleTimeout. It checks four things:
//   - the session ID is resumed via the custom cookie;
//   - POSTs with a missing, foreign-named or bogus cookie still succeed;
//   - the session is not fresh before IdleTimeout;
//   - the session is reported fresh once IdleTimeout has elapsed.
func Test_Session_WithConfig(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Get("key") == "value"
		},
		IdleTimeout: 1 * time.Second,
		KeyLookup:   "cookie:session_id_test",
		KeyGenerator: func() string {
			return "test"
		},
		source:      "cookie_test",
		sessionName: "session_id_test",
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + FromContext(c).ID())
	})
	app.Get("/isFresh", func(c fiber.Ctx) error {
		if !FromContext(c).Fresh() {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{
			Name:  "session_id_test",
			Value: FromContext(c).ID(),
		})
		return nil
	})

	h := app.Handler()

	// request sends one request on a brand-new RequestCtx.
	// An empty cookieName sends no cookie; an empty uri keeps the default URI.
	request := func(method, uri, cookieName, cookieValue string) *fasthttp.RequestCtx {
		rc := &fasthttp.RequestCtx{}
		rc.Request.Header.SetMethod(method)
		if cookieName != "" {
			rc.Request.Header.SetCookie(cookieName, cookieValue)
		}
		if uri != "" {
			rc.Request.SetRequestURI(uri)
		}
		h(rc)
		return rc
	}

	// First visit without a cookie: a session cookie is issued.
	rc := request(fiber.MethodGet, "", "", "")
	require.Equal(t, fiber.StatusOK, rc.Response.StatusCode())
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, setCookie, "Expected Set-Cookie header to be present")
	kv := strings.SplitN(strings.SplitN(setCookie, ";", 2)[0], "=", 2)
	require.Len(t, kv, 2, "Expected Set-Cookie header to contain a token")
	token := kv[1]
	require.Equal(t, "value="+token, string(rc.Response.Body()))

	// Returning with the custom cookie yields the same session ID.
	rc = request(fiber.MethodGet, "", "session_id_test", token)
	require.Equal(t, fiber.StatusOK, rc.Response.StatusCode())
	require.Equal(t, "value="+token, string(rc.Response.Body()))

	// POST probes, in order: the right cookie, no cookie, a cookie under the
	// wrong name, and the right name with a bogus value. Each must succeed.
	probes := []struct{ name, value string }{
		{"session_id_test", token},
		{"", ""},
		{"session_id", token},
		{"session_id_test", "wrong"},
	}
	for _, p := range probes {
		rc = request(fiber.MethodPost, "", p.name, p.value)
		require.Equal(t, fiber.StatusOK, rc.Response.StatusCode())
	}

	// Within IdleTimeout the session is not reported as fresh.
	rc = request(fiber.MethodGet, "/isFresh", "session_id_test", token)
	require.Equal(t, fiber.StatusInternalServerError, rc.Response.StatusCode())

	// Once IdleTimeout has elapsed the session is reported as fresh.
	time.Sleep(1200 * time.Millisecond)
	rc = request(fiber.MethodGet, "/isFresh", "session_id_test", token)
	require.Equal(t, fiber.StatusOK, rc.Response.StatusCode())
}
// Test_Session_Next checks that Config.Next controls whether the middleware
// runs.
//   - When Next returns false, a session is attached and its ID is echoed back.
//   - When Next returns true, the middleware is skipped: FromContext yields nil
//     and the handler answers 500.
func Test_Session_Next(t *testing.T) {
	t.Parallel()

	// skip is the value Next returns; reads and writes go through skipMu.
	var (
		skip   bool
		skipMu sync.RWMutex
	)

	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			skipMu.RLock()
			defer skipMu.RUnlock()
			return skip
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendString("value=" + sess.ID())
	})

	h := app.Handler()

	// Next returns false: the middleware runs and issues a session cookie.
	first := &fasthttp.RequestCtx{}
	first.Request.Header.SetMethod(fiber.MethodGet)
	h(first)
	require.Equal(t, fiber.StatusOK, first.Response.StatusCode())

	cookie := string(first.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, cookie, "Expected Set-Cookie header to be present")
	kv := strings.SplitN(strings.SplitN(cookie, ";", 2)[0], "=", 2)
	require.Len(t, kv, 2, "Expected Set-Cookie header to contain a token")
	require.Equal(t, "value="+kv[1], string(first.Response.Body()))

	// Next returns true: the middleware is bypassed.
	skipMu.Lock()
	skip = true
	skipMu.Unlock()

	second := &fasthttp.RequestCtx{}
	second.Request.Header.SetMethod(fiber.MethodGet)
	h(second)
	require.Equal(t, fiber.StatusInternalServerError, second.Response.StatusCode())
}
// Test_Session_Middleware_Store checks that the session attached by the
// handler from NewWithStore reports the very Store instance returned alongside
// that handler.
func Test_Session_Middleware_Store(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	handler, wantStore := NewWithStore()
	app.Use(handler)

	app.Get("/", func(c fiber.Ctx) error {
		status := fiber.StatusOK
		if FromContext(c).Store() != wantStore {
			status = fiber.StatusInternalServerError
		}
		return c.SendStatus(status)
	})

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	app.Handler()(reqCtx)
	require.Equal(t, fiber.StatusOK, reqCtx.Response.StatusCode())
}