fiber/middleware/session/extractors_test.go
Jason McNeil 979e7cd6b1
feat(middleware/session): Introduce Extractor pattern for session ID retrieval (#3625)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-07-28 16:48:22 +02:00

234 lines
5.7 KiB
Go

package session
import (
"context"
"net/http"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// newRequest creates a new *http.Request for Fiber's app.Test
func newRequest(method, target string) *http.Request {
req, err := http.NewRequestWithContext(context.Background(), method, target, nil)
if err != nil {
panic(err)
}
return req
}
func TestFromCookie(t *testing.T) {
t.Parallel()
app := fiber.New()
extractor := FromCookie("session_id")
t.Run("success", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie("session_id", "test-session-id")
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "test-session-id", sessionID)
})
t.Run("missing cookie", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Equal(t, ErrMissingSessionIDInCookie, err)
require.Empty(t, sessionID)
})
}
func TestFromHeader(t *testing.T) {
t.Parallel()
app := fiber.New()
extractor := FromHeader("X-Session-ID")
t.Run("success", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.Set("X-Session-ID", "test-session-id")
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "test-session-id", sessionID)
})
t.Run("missing header", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Equal(t, ErrMissingSessionIDInHeader, err)
require.Empty(t, sessionID)
})
}
func TestFromQuery(t *testing.T) {
t.Parallel()
app := fiber.New()
extractor := FromQuery("session_id")
t.Run("success", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().SetRequestURI("/test?session_id=test-session-id")
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "test-session-id", sessionID)
})
t.Run("missing query param", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().SetRequestURI("/test")
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Equal(t, ErrMissingSessionIDInQuery, err)
require.Empty(t, sessionID)
})
}
func TestFromForm(t *testing.T) {
t.Parallel()
app := fiber.New()
extractor := FromForm("session_id")
t.Run("success", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetMethod("POST")
ctx.Request().Header.SetContentType("application/x-www-form-urlencoded")
ctx.Request().SetBodyString("session_id=test-session-id")
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "test-session-id", sessionID)
})
t.Run("missing form field", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetMethod("POST")
ctx.Request().Header.SetContentType("application/x-www-form-urlencoded")
ctx.Request().SetBodyString("other_field=value")
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Equal(t, ErrMissingSessionIDInForm, err)
require.Empty(t, sessionID)
})
}
func TestFromParam(t *testing.T) {
t.Parallel()
app := fiber.New()
// FromParam
app.Get("/test/:csrf", func(c fiber.Ctx) error {
token, err := FromParam("csrf").Extract(c)
require.NoError(t, err)
require.Equal(t, "token_from_param", token)
return nil
})
// Note: This test is more complex as it requires route setup
// In a real scenario, you'd set up a route with parameters
t.Run("success", func(t *testing.T) {
t.Parallel()
_, err := app.Test(newRequest(fiber.MethodGet, "/test/token_from_param"))
require.NoError(t, err)
})
}
func TestChain(t *testing.T) {
t.Parallel()
app := fiber.New()
t.Run("first extractor succeeds", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie("session_id", "cookie-session-id")
ctx.Request().Header.Set("X-Session-ID", "header-session-id")
extractor := Chain(
FromCookie("session_id"),
FromHeader("X-Session-ID"),
)
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "cookie-session-id", sessionID) // First extractor wins
})
t.Run("second extractor succeeds", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.Set("X-Session-ID", "header-session-id")
extractor := Chain(
FromCookie("session_id"),
FromHeader("X-Session-ID"),
)
sessionID, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "header-session-id", sessionID)
})
t.Run("all extractors fail", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
extractor := Chain(
FromCookie("session_id"),
FromHeader("X-Session-ID"),
)
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Empty(t, sessionID)
})
t.Run("empty chain", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
extractor := Chain()
sessionID, err := extractor.Extract(ctx)
require.Error(t, err)
require.Equal(t, ErrMissingSessionID, err)
require.Empty(t, sessionID)
})
}