🔥 Feature: Enhance Session Middleware Context Handling

pull/3339/head
JIeJaitt 2025-03-11 17:43:30 +08:00
parent dd222a1e82
commit 71d9178f8f
4 changed files with 208 additions and 40 deletions

View File

@ -8,6 +8,7 @@ import (
"sync"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)
// Middleware holds session data and configuration.
@ -92,7 +93,7 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) {
m.initialize(c, cfg)
// Add session to Go context
ctx := context.WithValue(c.Context(), sessionContextKey, m.Session)
ctx := context.WithValue(c.Context(), middlewareContextKey, m)
c.SetContext(ctx)
stackErr := c.Next()
@ -127,7 +128,7 @@ func (m *Middleware) initialize(c fiber.Ctx, cfg Config) {
m.ctx = c
c.Locals(middlewareContextKey, m)
c.SetContext(context.WithValue(c.Context(), sessionContextKey, session))
c.SetContext(context.WithValue(c.Context(), middlewareContextKey, session))
}
// saveSession handles session saving and error management after the response.
@ -170,23 +171,25 @@ func releaseMiddleware(m *Middleware) {
middlewarePool.Put(m)
}
// FromContext returns the Middleware from the Fiber context.
//
// Parameters:
// - c: The Fiber context.
//
// Returns:
// - *Middleware: The middleware object if found, otherwise nil.
//
// Usage:
//
// m := session.FromContext(c)
func FromContext(c fiber.Ctx) *Middleware {
m, ok := c.Locals(middlewareContextKey).(*Middleware)
if !ok {
return nil
// FromContext returns the session from context.
// Supported context types:
// - fiber.Ctx: Retrieves session from Locals
// - context.Context: Retrieves session from context values
// If there is no session, nil is returned.
func FromContext(c any) *Middleware {
switch ctx := c.(type) {
case fiber.Ctx:
if m, ok := ctx.Locals(middlewareContextKey).(*Middleware); ok && m != nil {
return m
}
case context.Context:
if m, ok := ctx.Value(middlewareContextKey).(*Middleware); ok {
return m
}
default:
log.Errorf("Unsupported context type: %T. Expected fiber.Ctx or context.Context", c)
}
return m
return nil
}
// Set sets a key-value pair in the session.

View File

@ -455,7 +455,7 @@ func Test_Session_Middleware_Store(t *testing.T) {
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}
func Test_Session_GoContext(t *testing.T) {
func Test_Session_Context(t *testing.T) {
t.Parallel()
app := fiber.New()
@ -474,7 +474,7 @@ func Test_Session_GoContext(t *testing.T) {
sess.Set("test_key", "test_value")
// Get the session from Go context
goCtxSess := FromGoContext(c.Context())
goCtxSess := FromContext(c.Context())
// Verify both sessions are the same
if goCtxSess == nil {
@ -506,11 +506,11 @@ func Test_Session_GoContext_EdgeCases(t *testing.T) {
t.Parallel()
// Test with nil context
sess := FromGoContext(nil) //nolint:staticcheck // Intentionally testing nil context behavior
sess := FromContext(nil) //nolint:staticcheck // Intentionally testing nil context behavior
require.Nil(t, sess, "Session should be nil with nil context")
// Test with context that doesn't have a session
ctx := context.Background()
sess = FromGoContext(ctx)
sess = FromContext(ctx)
require.Nil(t, sess, "Session should be nil when not in context")
}

View File

@ -31,15 +31,6 @@ const (
absExpirationKey absExpirationKeyType = iota
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
const (
// sessionContextKey is the key used to store the *Session in the Go context.
sessionContextKey contextKey = iota
)
// Session pool for reusing byte buffers.
var byteBufferPool = sync.Pool{
New: func() any {
@ -522,14 +513,14 @@ func (s *Session) setAbsExpiration(absExpiration time.Time) {
s.Set(absExpirationKey, absExpiration)
}
// FromGoContext returns the Session from the go context.
// If there is no session, nil is returned.
func FromGoContext(ctx context.Context) *Session {
if ctx == nil {
return nil
// GetAndSetInContext Get the session and store it in the Go context
func (s *Store) GetAndSetInContext(c fiber.Ctx) (*Session, error) {
sess, err := s.Get(c)
if err != nil {
return nil, err
}
if sess, ok := ctx.Value(sessionContextKey).(*Session); ok {
return sess
}
return nil
// Add to Go context
c.SetContext(context.WithValue(c.Context(), middlewareContextKey, sess))
return sess, nil
}

View File

@ -1316,3 +1316,177 @@ func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) {
// Check that the error message is as expected
require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")
}
// Test_Session_GetAndSetInContext tests the functionality of GetAndSetInContext method
// which retrieves a session and stores it in the context.
func Test_Session_GetAndSetInContext(t *testing.T) {
t.Parallel()
t.Run("new_session_creation_and_context_storage", func(t *testing.T) {
t.Parallel()
// Setup test environment
store := NewStore()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Test creating a new session and storing in context
sess, err := store.GetAndSetInContext(ctx)
// Verify operation success and session creation
require.NoError(t, err, "should not error when creating new session")
require.NotNil(t, sess, "session should not be nil")
require.True(t, sess.Fresh(), "new session should be marked as fresh")
// Capture session ID for reference
sessionID := sess.ID()
require.NotEmpty(t, sessionID, "session ID should not be empty")
// Verify session was properly set in Go context
// The session is stored directly in the context, not as middleware
val := ctx.Context().Value(middlewareContextKey)
sessionFromCtx, ok := val.(*Session)
if !ok {
t.Fatalf("session should be accessible from context")
}
require.Equal(t, sessionID, sessionFromCtx.ID(), "session ID should match between original and context-retrieved session")
// Properly save and release resources
require.NoError(t, sess.Save(), "session save should succeed")
sess.Release()
})
t.Run("session_retrieval_and_data_persistence", func(t *testing.T) {
t.Parallel()
// Setup test environment
store := NewStore()
app := fiber.New()
// Step 1: Create initial session
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.GetAndSetInContext(ctx)
require.NoError(t, err, "should create initial session without error")
// Set test data and save
sessionID := sess.ID()
sess.Set("testKey", "testValue")
require.NoError(t, sess.Save(), "should save session without error")
sess.Release()
app.ReleaseCtx(ctx)
// Step 2: Retrieve session in a new context using cookie
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, sessionID)
// Get and verify session
retrievedSess, err := store.GetAndSetInContext(ctx)
defer retrievedSess.Release()
// Verify operation success and correct session retrieval
require.NoError(t, err, "should retrieve existing session without error")
require.NotNil(t, retrievedSess, "retrieved session should not be nil")
require.False(t, retrievedSess.Fresh(), "retrieved session should not be fresh")
require.Equal(t, sessionID, retrievedSess.ID(), "session ID should be preserved")
// Verify data persistence
require.Equal(t, "testValue", retrievedSess.Get("testKey"), "session data should persist between requests")
// Verify session in context matches direct retrieval
val := ctx.Context().Value(middlewareContextKey)
sessionFromCtx, ok := val.(*Session)
if !ok {
t.Fatalf("session should be accessible from context")
}
require.NotNil(t, sessionFromCtx, "session should exist in context")
require.Equal(t, sessionID, sessionFromCtx.ID(), "session IDs should match")
require.Equal(t, "testValue", sessionFromCtx.Get("testKey"), "session data should be accessible from context")
})
t.Run("session_data_modification", func(t *testing.T) {
t.Parallel()
// Setup test environment
store := NewStore()
app := fiber.New()
// Create and save initial session
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.GetAndSetInContext(ctx)
require.NoError(t, err, "should create session without error")
sessionID := sess.ID()
require.NoError(t, sess.Save(), "should save session without error")
sess.Release()
app.ReleaseCtx(ctx)
// Retrieve and modify session
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, sessionID)
sess, err = store.GetAndSetInContext(ctx)
require.NoError(t, err, "should retrieve session without error")
// Modify session data
sess.Set("modifiedKey", "modifiedValue")
require.NoError(t, sess.Save(), "should save modified session without error")
// Verify context has updated data immediately
val := ctx.Context().Value(middlewareContextKey)
sessionFromCtx, ok := val.(*Session)
if !ok {
t.Fatalf("session should be accessible from context")
}
require.Equal(t, "modifiedValue", sessionFromCtx.Get("modifiedKey"), "session in context should have updated data")
sess.Release()
})
}
// go test -run Test_Session_GetAndSetInContext_Error
func Test_Session_GetAndSetInContext_Error(t *testing.T) {
t.Parallel()
// Create a new store with mock storage that returns an error
mockStore := NewStore(Config{
Storage: &errorStorage{},
})
// Create a new fiber instance
app := fiber.New()
// Create a new fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Attempt to get and set session in context
sess, err := mockStore.GetAndSetInContext(ctx)
require.Error(t, err, "should return error when storage fails")
require.Nil(t, sess, "session should be nil on error")
require.Contains(t, err.Error(), "mock error", "error should be from storage")
}
// errorStorage implements Storage interface for error testing
type errorStorage struct{}
func (*errorStorage) Get(string) ([]byte, error) {
return nil, errors.New("mock error: Get failed")
}
func (*errorStorage) Set(string, []byte, time.Duration) error {
return errors.New("mock error: Set failed")
}
func (*errorStorage) Delete(string) error {
return errors.New("mock error: Delete failed")
}
func (*errorStorage) Reset() error {
return errors.New("mock error: Reset failed")
}
func (*errorStorage) Close() error {
return errors.New("mock error: Close failed")
}