mirror of https://github.com/gofiber/fiber.git
🔥 Feature: Enhance Session Middleware Context Handling
parent
dd222a1e82
commit
71d9178f8f
|
@ -8,6 +8,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/log"
|
||||
)
|
||||
|
||||
// Middleware holds session data and configuration.
|
||||
|
@ -92,7 +93,7 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) {
|
|||
m.initialize(c, cfg)
|
||||
|
||||
// Add session to Go context
|
||||
ctx := context.WithValue(c.Context(), sessionContextKey, m.Session)
|
||||
ctx := context.WithValue(c.Context(), middlewareContextKey, m)
|
||||
c.SetContext(ctx)
|
||||
|
||||
stackErr := c.Next()
|
||||
|
@ -127,7 +128,7 @@ func (m *Middleware) initialize(c fiber.Ctx, cfg Config) {
|
|||
m.ctx = c
|
||||
|
||||
c.Locals(middlewareContextKey, m)
|
||||
c.SetContext(context.WithValue(c.Context(), sessionContextKey, session))
|
||||
c.SetContext(context.WithValue(c.Context(), middlewareContextKey, session))
|
||||
}
|
||||
|
||||
// saveSession handles session saving and error management after the response.
|
||||
|
@ -170,23 +171,25 @@ func releaseMiddleware(m *Middleware) {
|
|||
middlewarePool.Put(m)
|
||||
}
|
||||
|
||||
// FromContext returns the Middleware from the Fiber context.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Fiber context.
|
||||
//
|
||||
// Returns:
|
||||
// - *Middleware: The middleware object if found, otherwise nil.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// m := session.FromContext(c)
|
||||
func FromContext(c fiber.Ctx) *Middleware {
|
||||
m, ok := c.Locals(middlewareContextKey).(*Middleware)
|
||||
if !ok {
|
||||
return nil
|
||||
// FromContext returns the session from context.
|
||||
// Supported context types:
|
||||
// - fiber.Ctx: Retrieves session from Locals
|
||||
// - context.Context: Retrieves session from context values
|
||||
// If there is no session, nil is returned.
|
||||
func FromContext(c any) *Middleware {
|
||||
switch ctx := c.(type) {
|
||||
case fiber.Ctx:
|
||||
if m, ok := ctx.Locals(middlewareContextKey).(*Middleware); ok && m != nil {
|
||||
return m
|
||||
}
|
||||
case context.Context:
|
||||
if m, ok := ctx.Value(middlewareContextKey).(*Middleware); ok {
|
||||
return m
|
||||
}
|
||||
default:
|
||||
log.Errorf("Unsupported context type: %T. Expected fiber.Ctx or context.Context", c)
|
||||
}
|
||||
return m
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set sets a key-value pair in the session.
|
||||
|
|
|
@ -455,7 +455,7 @@ func Test_Session_Middleware_Store(t *testing.T) {
|
|||
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
func Test_Session_GoContext(t *testing.T) {
|
||||
func Test_Session_Context(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
|
@ -474,7 +474,7 @@ func Test_Session_GoContext(t *testing.T) {
|
|||
sess.Set("test_key", "test_value")
|
||||
|
||||
// Get the session from Go context
|
||||
goCtxSess := FromGoContext(c.Context())
|
||||
goCtxSess := FromContext(c.Context())
|
||||
|
||||
// Verify both sessions are the same
|
||||
if goCtxSess == nil {
|
||||
|
@ -506,11 +506,11 @@ func Test_Session_GoContext_EdgeCases(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
// Test with nil context
|
||||
sess := FromGoContext(nil) //nolint:staticcheck // Intentionally testing nil context behavior
|
||||
sess := FromContext(nil) //nolint:staticcheck // Intentionally testing nil context behavior
|
||||
require.Nil(t, sess, "Session should be nil with nil context")
|
||||
|
||||
// Test with context that doesn't have a session
|
||||
ctx := context.Background()
|
||||
sess = FromGoContext(ctx)
|
||||
sess = FromContext(ctx)
|
||||
require.Nil(t, sess, "Session should be nil when not in context")
|
||||
}
|
||||
|
|
|
@ -31,15 +31,6 @@ const (
|
|||
absExpirationKey absExpirationKeyType = iota
|
||||
)
|
||||
|
||||
// The contextKey type is unexported to prevent collisions with context keys defined in
|
||||
// other packages.
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
// sessionContextKey is the key used to store the *Session in the Go context.
|
||||
sessionContextKey contextKey = iota
|
||||
)
|
||||
|
||||
// Session pool for reusing byte buffers.
|
||||
var byteBufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
|
@ -522,14 +513,14 @@ func (s *Session) setAbsExpiration(absExpiration time.Time) {
|
|||
s.Set(absExpirationKey, absExpiration)
|
||||
}
|
||||
|
||||
// FromGoContext returns the Session from the go context.
|
||||
// If there is no session, nil is returned.
|
||||
func FromGoContext(ctx context.Context) *Session {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
// GetAndSetInContext Get the session and store it in the Go context
|
||||
func (s *Store) GetAndSetInContext(c fiber.Ctx) (*Session, error) {
|
||||
sess, err := s.Get(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sess, ok := ctx.Value(sessionContextKey).(*Session); ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
|
||||
// Add to Go context
|
||||
c.SetContext(context.WithValue(c.Context(), middlewareContextKey, sess))
|
||||
return sess, nil
|
||||
}
|
||||
|
|
|
@ -1316,3 +1316,177 @@ func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) {
|
|||
// Check that the error message is as expected
|
||||
require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")
|
||||
}
|
||||
|
||||
// Test_Session_GetAndSetInContext tests the functionality of GetAndSetInContext method
|
||||
// which retrieves a session and stores it in the context.
|
||||
func Test_Session_GetAndSetInContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("new_session_creation_and_context_storage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup test environment
|
||||
store := NewStore()
|
||||
app := fiber.New()
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Test creating a new session and storing in context
|
||||
sess, err := store.GetAndSetInContext(ctx)
|
||||
|
||||
// Verify operation success and session creation
|
||||
require.NoError(t, err, "should not error when creating new session")
|
||||
require.NotNil(t, sess, "session should not be nil")
|
||||
require.True(t, sess.Fresh(), "new session should be marked as fresh")
|
||||
|
||||
// Capture session ID for reference
|
||||
sessionID := sess.ID()
|
||||
require.NotEmpty(t, sessionID, "session ID should not be empty")
|
||||
|
||||
// Verify session was properly set in Go context
|
||||
// The session is stored directly in the context, not as middleware
|
||||
val := ctx.Context().Value(middlewareContextKey)
|
||||
sessionFromCtx, ok := val.(*Session)
|
||||
if !ok {
|
||||
t.Fatalf("session should be accessible from context")
|
||||
}
|
||||
require.Equal(t, sessionID, sessionFromCtx.ID(), "session ID should match between original and context-retrieved session")
|
||||
|
||||
// Properly save and release resources
|
||||
require.NoError(t, sess.Save(), "session save should succeed")
|
||||
sess.Release()
|
||||
})
|
||||
|
||||
t.Run("session_retrieval_and_data_persistence", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup test environment
|
||||
store := NewStore()
|
||||
app := fiber.New()
|
||||
|
||||
// Step 1: Create initial session
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
sess, err := store.GetAndSetInContext(ctx)
|
||||
require.NoError(t, err, "should create initial session without error")
|
||||
|
||||
// Set test data and save
|
||||
sessionID := sess.ID()
|
||||
sess.Set("testKey", "testValue")
|
||||
require.NoError(t, sess.Save(), "should save session without error")
|
||||
sess.Release()
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// Step 2: Retrieve session in a new context using cookie
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Request().Header.SetCookie(store.sessionName, sessionID)
|
||||
|
||||
// Get and verify session
|
||||
retrievedSess, err := store.GetAndSetInContext(ctx)
|
||||
defer retrievedSess.Release()
|
||||
|
||||
// Verify operation success and correct session retrieval
|
||||
require.NoError(t, err, "should retrieve existing session without error")
|
||||
require.NotNil(t, retrievedSess, "retrieved session should not be nil")
|
||||
require.False(t, retrievedSess.Fresh(), "retrieved session should not be fresh")
|
||||
require.Equal(t, sessionID, retrievedSess.ID(), "session ID should be preserved")
|
||||
|
||||
// Verify data persistence
|
||||
require.Equal(t, "testValue", retrievedSess.Get("testKey"), "session data should persist between requests")
|
||||
|
||||
// Verify session in context matches direct retrieval
|
||||
val := ctx.Context().Value(middlewareContextKey)
|
||||
sessionFromCtx, ok := val.(*Session)
|
||||
if !ok {
|
||||
t.Fatalf("session should be accessible from context")
|
||||
}
|
||||
require.NotNil(t, sessionFromCtx, "session should exist in context")
|
||||
require.Equal(t, sessionID, sessionFromCtx.ID(), "session IDs should match")
|
||||
require.Equal(t, "testValue", sessionFromCtx.Get("testKey"), "session data should be accessible from context")
|
||||
})
|
||||
|
||||
t.Run("session_data_modification", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup test environment
|
||||
store := NewStore()
|
||||
app := fiber.New()
|
||||
|
||||
// Create and save initial session
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
sess, err := store.GetAndSetInContext(ctx)
|
||||
require.NoError(t, err, "should create session without error")
|
||||
sessionID := sess.ID()
|
||||
require.NoError(t, sess.Save(), "should save session without error")
|
||||
sess.Release()
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// Retrieve and modify session
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Request().Header.SetCookie(store.sessionName, sessionID)
|
||||
|
||||
sess, err = store.GetAndSetInContext(ctx)
|
||||
require.NoError(t, err, "should retrieve session without error")
|
||||
|
||||
// Modify session data
|
||||
sess.Set("modifiedKey", "modifiedValue")
|
||||
require.NoError(t, sess.Save(), "should save modified session without error")
|
||||
|
||||
// Verify context has updated data immediately
|
||||
val := ctx.Context().Value(middlewareContextKey)
|
||||
sessionFromCtx, ok := val.(*Session)
|
||||
if !ok {
|
||||
t.Fatalf("session should be accessible from context")
|
||||
}
|
||||
require.Equal(t, "modifiedValue", sessionFromCtx.Get("modifiedKey"), "session in context should have updated data")
|
||||
|
||||
sess.Release()
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run Test_Session_GetAndSetInContext_Error
|
||||
func Test_Session_GetAndSetInContext_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a new store with mock storage that returns an error
|
||||
mockStore := NewStore(Config{
|
||||
Storage: &errorStorage{},
|
||||
})
|
||||
|
||||
// Create a new fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
// Create a new fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Attempt to get and set session in context
|
||||
sess, err := mockStore.GetAndSetInContext(ctx)
|
||||
require.Error(t, err, "should return error when storage fails")
|
||||
require.Nil(t, sess, "session should be nil on error")
|
||||
require.Contains(t, err.Error(), "mock error", "error should be from storage")
|
||||
}
|
||||
|
||||
// errorStorage implements Storage interface for error testing
|
||||
type errorStorage struct{}
|
||||
|
||||
func (*errorStorage) Get(string) ([]byte, error) {
|
||||
return nil, errors.New("mock error: Get failed")
|
||||
}
|
||||
|
||||
func (*errorStorage) Set(string, []byte, time.Duration) error {
|
||||
return errors.New("mock error: Set failed")
|
||||
}
|
||||
|
||||
func (*errorStorage) Delete(string) error {
|
||||
return errors.New("mock error: Delete failed")
|
||||
}
|
||||
|
||||
func (*errorStorage) Reset() error {
|
||||
return errors.New("mock error: Reset failed")
|
||||
}
|
||||
|
||||
func (*errorStorage) Close() error {
|
||||
return errors.New("mock error: Close failed")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue