fix(middleware/session): mutex for thread safety (#3049)

* fix(middleware/session): mutex for thread safety

* chore: Remove extra release and acquire ctx calls in session_test.go

* feat: Remove unnecessary session mutex lock in decodeSessionData function
pull/3017/head^2
Jason McNeil 2024-06-29 16:47:09 -03:00 committed by GitHub
parent dbba6cfa69
commit 83731cef85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 139 additions and 24 deletions

View File

@ -13,6 +13,7 @@ import (
)
type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx fiber.Ctx // fiber context
@ -56,11 +57,15 @@ func releaseSession(s *Session) {
// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}
// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}
@ -101,6 +106,9 @@ func (s *Session) Destroy() error {
// Reset local data
s.data.Reset()
s.mu.Lock()
defer s.mu.Unlock()
// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
@ -113,6 +121,9 @@ func (s *Session) Destroy() error {
// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
@ -137,6 +148,9 @@ func (s *Session) Reset() error {
// Reset expiration
s.exp = 0
s.mu.Lock()
defer s.mu.Unlock()
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
@ -153,10 +167,7 @@ func (s *Session) Reset() error {
// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
// Create a new id
s.id = s.config.KeyGenerator()
// We assign a new id to the session, so the session must be fresh
s.fresh = true
}
@ -167,6 +178,9 @@ func (s *Session) Save() error {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
@ -176,25 +190,23 @@ func (s *Session) Save() error {
s.setSession()
// Convert data to bytes
mux.Lock()
defer mux.Unlock()
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}
// copy the data in buffer
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())
// pass copied bytes with session id to provider
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}
// Release session
// TODO: It's not safe to use the Session after called Save()
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)
return nil
@ -210,6 +222,8 @@ func (s *Session) Keys() []string {
// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}
@ -275,3 +289,13 @@ func (s *Session) delSession() {
fasthttp.ReleaseCookie(fcookie)
}
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

View File

@ -1,6 +1,8 @@
package session
import (
"errors"
"sync"
"testing"
"time"
@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
})
})
}
// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()
var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test
// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Set a value
sess.Set("name", "john")
// get the session id
id := sess.ID()
// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}
// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}
// Release the context
app.ReleaseCtx(localCtx)
// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)
// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)
// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}
// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}
// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}
// Delete the key
sess.Delete("name")
// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}
// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}
wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent
// Check for errors sent to errChan
for err := range errChan {
require.NoError(t, err)
}
}

View File

@ -4,7 +4,6 @@ import (
"encoding/gob"
"errors"
"fmt"
"sync"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
@ -14,9 +13,6 @@ import (
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
// mux is a global mutex for session operations.
var mux sync.Mutex
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error {
}
return s.Storage.Delete(id)
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}