diff --git a/middleware/session/session.go b/middleware/session/session.go index 9ab5401c..7272ba18 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -42,6 +42,7 @@ func acquireSession() *Session { } func releaseSession(s *Session) { + s.mu.Lock() s.id = "" s.exp = 0 s.ctx = nil @@ -52,6 +53,7 @@ func releaseSession(s *Session) { if s.byteBuffer != nil { s.byteBuffer.Reset() } + s.mu.Unlock() sessionPool.Put(s) } @@ -106,8 +108,8 @@ func (s *Session) Destroy() error { // Reset local data s.data.Reset() - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() // Use external Storage if exist if err := s.config.Storage.Delete(s.id); err != nil { @@ -141,6 +143,10 @@ func (s *Session) Reset() error { if s.data != nil { s.data.Reset() } + + s.mu.Lock() + defer s.mu.Unlock() + // Reset byte buffer if s.byteBuffer != nil { s.byteBuffer.Reset() @@ -148,9 +154,6 @@ func (s *Session) Reset() error { // Reset expiration s.exp = 0 - s.mu.Lock() - defer s.mu.Unlock() - // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -172,6 +175,11 @@ func (s *Session) refresh() { } // Save will update the storage and client cookie +// +// sess.Save() will save the session data to the storage and update the +// client cookie, and it will release the session after saving. +// +// It's not safe to use the session after calling Save(). func (s *Session) Save() error { // Better safe than sorry if s.data == nil { @@ -179,7 +187,6 @@ func (s *Session) Save() error { } s.mu.Lock() - defer s.mu.Unlock() // Check if session has your own expiration, otherwise use default value if s.exp <= 0 { @@ -205,6 +212,8 @@ func (s *Session) Save() error { return err } + s.mu.Unlock() + // Release session // TODO: It's not safe to use the Session after calling Save() releaseSession(s) diff --git a/middleware/session/store.go b/middleware/session/store.go index 249e8f53..05fba8e2 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -77,6 +77,10 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { // Create session object sess := acquireSession() + + sess.mu.Lock() + defer sess.mu.Unlock() + sess.ctx = c sess.config = s sess.id = id @@ -84,6 +88,8 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { // Decode session data if found if rawData != nil { + sess.data.Lock() + defer sess.data.Unlock() if err := sess.decodeSessionData(rawData); err != nil { return nil, fmt.Errorf("failed to decode session data: %w", err) }