From 83731cef85c130eba7c0b1997318fde86b11f22f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 29 Jun 2024 16:47:09 -0300 Subject: [PATCH] fix(middleware/session): mutex for thread safety (#3049) * fix(middleware/session): mutex for thread safety * chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData function --- middleware/session/session.go | 40 ++++++++--- middleware/session/session_test.go | 107 +++++++++++++++++++++++++++++ middleware/session/store.go | 16 ----- 3 files changed, 139 insertions(+), 24 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index c2573439..9ab5401c 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -13,6 +13,7 @@ import ( ) type Session struct { + mu sync.RWMutex // Mutex to protect non-data fields id string // session id fresh bool // if new session ctx fiber.Ctx // fiber context @@ -56,11 +57,15 @@ func releaseSession(s *Session) { // Fresh is true if the current session is new func (s *Session) Fresh() bool { + s.mu.RLock() + defer s.mu.RUnlock() return s.fresh } // ID returns the session id func (s *Session) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.id } @@ -101,6 +106,9 @@ func (s *Session) Destroy() error { // Reset local data s.data.Reset() + s.mu.Lock() + defer s.mu.Unlock() + // Use external Storage if exist if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -113,6 +121,9 @@ func (s *Session) Destroy() error { // Regenerate generates a new session id and delete the old one from Storage func (s *Session) Regenerate() error { + s.mu.Lock() + defer s.mu.Unlock() + // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -137,6 +148,9 @@ func (s *Session) Reset() error { // Reset expiration s.exp = 0 + s.mu.Lock() + defer s.mu.Unlock() + // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -153,10 +167,7 @@ func (s *Session) Reset() error { // refresh generates a new session, and set session.fresh to be true func (s *Session) refresh() { - // Create a new id s.id = s.config.KeyGenerator() - - // We assign a new id to the session, so the session must be fresh s.fresh = true } @@ -167,6 +178,9 @@ func (s *Session) Save() error { return nil } + s.mu.Lock() + defer s.mu.Unlock() + // Check if session has your own expiration, otherwise use default value if s.exp <= 0 { s.exp = s.config.Expiration @@ -176,25 +190,23 @@ func (s *Session) Save() error { s.setSession() // Convert data to bytes - mux.Lock() - defer mux.Unlock() encCache := gob.NewEncoder(s.byteBuffer) err := encCache.Encode(&s.data.Data) if err != nil { return fmt.Errorf("failed to encode data: %w", err) } - // copy the data in buffer + // Copy the data in buffer encodedBytes := make([]byte, s.byteBuffer.Len()) copy(encodedBytes, s.byteBuffer.Bytes()) - // pass copied bytes with session id to provider + // Pass copied bytes with session id to provider if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil { return err } // Release session - // TODO: It's not safe to use the Session after called Save() + // TODO: It's not safe to use the Session after calling Save() releaseSession(s) return nil @@ -210,6 +222,8 @@ func (s *Session) Keys() []string { // SetExpiry sets a specific expiration for this session func (s *Session) SetExpiry(exp time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() s.exp = exp } @@ -275,3 +289,13 @@ func (s *Session) delSession() { fasthttp.ReleaseCookie(fcookie) } } + +// decodeSessionData decodes the session data from raw bytes. +func (s *Session) decodeSessionData(rawData []byte) error { + _, _ = s.byteBuffer.Write(rawData) + encCache := gob.NewDecoder(s.byteBuffer) + if err := encCache.Decode(&s.data.Data); err != nil { + return fmt.Errorf("failed to decode session data: %w", err) + } + return nil +} diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index aa7551eb..fa12d690 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -1,6 +1,8 @@ package session import ( + "errors" + "sync" "testing" "time" @@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { }) }) } + +// go test -v -race -run Test_Session_Concurrency ./... +func Test_Session_Concurrency(t *testing.T) { + t.Parallel() + app := fiber.New() + store := New() + + var wg sync.WaitGroup + errChan := make(chan error, 10) // Buffered channel to collect errors + const numGoroutines = 10 // Number of concurrent goroutines to test + + // Start numGoroutines goroutines + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + localCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + sess, err := store.Get(localCtx) + if err != nil { + errChan <- err + return + } + + // Set a value + sess.Set("name", "john") + + // get the session id + id := sess.ID() + + // Check if the session is fresh + if !sess.Fresh() { + errChan <- errors.New("session should be fresh") + return + } + + // Save the session + if err := sess.Save(); err != nil { + errChan <- err + return + } + + // Release the context + app.ReleaseCtx(localCtx) + + // Acquire a new context + localCtx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(localCtx) + + // Set the session id in the header + localCtx.Request().Header.SetCookie(store.sessionName, id) + + // Get the session + sess, err = store.Get(localCtx) + if err != nil { + errChan <- err + return + } + + // Get the value + name := sess.Get("name") + if name != "john" { + errChan <- errors.New("name should be john") + return + } + + // Get ID from the session + if sess.ID() != id { + errChan <- errors.New("id should be the same") + return + } + + // Check if the session is fresh + if sess.Fresh() { + errChan <- errors.New("session should not be fresh") + return + } + + // Delete the key + sess.Delete("name") + + // Get the value + name = sess.Get("name") + if name != nil { + errChan <- errors.New("name should be nil") + return + } + + // Destroy the session + if err := sess.Destroy(); err != nil { + errChan <- err + return + } + }() + } + + wg.Wait() // Wait for all goroutines to finish + close(errChan) // Close the channel to signal no more errors will be sent + + // Check for errors sent to errChan + for err := range errChan { + require.NoError(t, err) + } +} diff --git a/middleware/session/store.go b/middleware/session/store.go index 09f8da8e..249e8f53 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -4,7 +4,6 @@ import ( "encoding/gob" "errors" "fmt" - "sync" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" @@ -14,9 +13,6 @@ import ( // ErrEmptySessionID is an error that occurs when the session ID is empty. var ErrEmptySessionID = errors.New("session id cannot be empty") -// mux is a global mutex for session operations. -var mux sync.Mutex - // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int @@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error { } return s.Storage.Delete(id) } - -// decodeSessionData decodes the session data from raw bytes. -func (s *Session) decodeSessionData(rawData []byte) error { - mux.Lock() - defer mux.Unlock() - _, _ = s.byteBuffer.Write(rawData) - encCache := gob.NewDecoder(s.byteBuffer) - if err := encCache.Decode(&s.data.Data); err != nil { - return fmt.Errorf("failed to decode session data: %w", err) - } - return nil -}