mirror of https://github.com/gofiber/fiber.git
fix(middleware/session): mutex for thread safety (#3049)
* fix(middleware/session): mutex for thread safety * chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData functionpull/3017/head^2
parent
dbba6cfa69
commit
83731cef85
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
type Session struct {
|
||||
mu sync.RWMutex // Mutex to protect non-data fields
|
||||
id string // session id
|
||||
fresh bool // if new session
|
||||
ctx fiber.Ctx // fiber context
|
||||
|
@ -56,11 +57,15 @@ func releaseSession(s *Session) {
|
|||
|
||||
// Fresh is true if the current session is new
|
||||
func (s *Session) Fresh() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fresh
|
||||
}
|
||||
|
||||
// ID returns the session id
|
||||
func (s *Session) ID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.id
|
||||
}
|
||||
|
||||
|
@ -101,6 +106,9 @@ func (s *Session) Destroy() error {
|
|||
// Reset local data
|
||||
s.data.Reset()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Use external Storage if exist
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
|
@ -113,6 +121,9 @@ func (s *Session) Destroy() error {
|
|||
|
||||
// Regenerate generates a new session id and delete the old one from Storage
|
||||
func (s *Session) Regenerate() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Delete old id from storage
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
|
@ -137,6 +148,9 @@ func (s *Session) Reset() error {
|
|||
// Reset expiration
|
||||
s.exp = 0
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Delete old id from storage
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
|
@ -153,10 +167,7 @@ func (s *Session) Reset() error {
|
|||
|
||||
// refresh generates a new session, and set session.fresh to be true
|
||||
func (s *Session) refresh() {
|
||||
// Create a new id
|
||||
s.id = s.config.KeyGenerator()
|
||||
|
||||
// We assign a new id to the session, so the session must be fresh
|
||||
s.fresh = true
|
||||
}
|
||||
|
||||
|
@ -167,6 +178,9 @@ func (s *Session) Save() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if session has your own expiration, otherwise use default value
|
||||
if s.exp <= 0 {
|
||||
s.exp = s.config.Expiration
|
||||
|
@ -176,25 +190,23 @@ func (s *Session) Save() error {
|
|||
s.setSession()
|
||||
|
||||
// Convert data to bytes
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
encCache := gob.NewEncoder(s.byteBuffer)
|
||||
err := encCache.Encode(&s.data.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode data: %w", err)
|
||||
}
|
||||
|
||||
// copy the data in buffer
|
||||
// Copy the data in buffer
|
||||
encodedBytes := make([]byte, s.byteBuffer.Len())
|
||||
copy(encodedBytes, s.byteBuffer.Bytes())
|
||||
|
||||
// pass copied bytes with session id to provider
|
||||
// Pass copied bytes with session id to provider
|
||||
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Release session
|
||||
// TODO: It's not safe to use the Session after called Save()
|
||||
// TODO: It's not safe to use the Session after calling Save()
|
||||
releaseSession(s)
|
||||
|
||||
return nil
|
||||
|
@ -210,6 +222,8 @@ func (s *Session) Keys() []string {
|
|||
|
||||
// SetExpiry sets a specific expiration for this session
|
||||
func (s *Session) SetExpiry(exp time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.exp = exp
|
||||
}
|
||||
|
||||
|
@ -275,3 +289,13 @@ func (s *Session) delSession() {
|
|||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
_, _ = s.byteBuffer.Write(rawData)
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -race -run Test_Session_Concurrency ./...
|
||||
func Test_Session_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
store := New()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 10) // Buffered channel to collect errors
|
||||
const numGoroutines = 10 // Number of concurrent goroutines to test
|
||||
|
||||
// Start numGoroutines goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
sess, err := store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Set a value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// get the session id
|
||||
id := sess.ID()
|
||||
|
||||
// Check if the session is fresh
|
||||
if !sess.Fresh() {
|
||||
errChan <- errors.New("session should be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the session
|
||||
if err := sess.Save(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Release the context
|
||||
app.ReleaseCtx(localCtx)
|
||||
|
||||
// Acquire a new context
|
||||
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(localCtx)
|
||||
|
||||
// Set the session id in the header
|
||||
localCtx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
// Get the session
|
||||
sess, err = store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the value
|
||||
name := sess.Get("name")
|
||||
if name != "john" {
|
||||
errChan <- errors.New("name should be john")
|
||||
return
|
||||
}
|
||||
|
||||
// Get ID from the session
|
||||
if sess.ID() != id {
|
||||
errChan <- errors.New("id should be the same")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the session is fresh
|
||||
if sess.Fresh() {
|
||||
errChan <- errors.New("session should not be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the key
|
||||
sess.Delete("name")
|
||||
|
||||
// Get the value
|
||||
name = sess.Get("name")
|
||||
if name != nil {
|
||||
errChan <- errors.New("name should be nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
if err := sess.Destroy(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait() // Wait for all goroutines to finish
|
||||
close(errChan) // Close the channel to signal no more errors will be sent
|
||||
|
||||
// Check for errors sent to errChan
|
||||
for err := range errChan {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/gofiber/fiber/v3/internal/storage/memory"
|
||||
|
@ -14,9 +13,6 @@ import (
|
|||
// ErrEmptySessionID is an error that occurs when the session ID is empty.
|
||||
var ErrEmptySessionID = errors.New("session id cannot be empty")
|
||||
|
||||
// mux is a global mutex for session operations.
|
||||
var mux sync.Mutex
|
||||
|
||||
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
|
||||
type sessionIDKey int
|
||||
|
||||
|
@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error {
|
|||
}
|
||||
return s.Storage.Delete(id)
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
_, _ = s.byteBuffer.Write(rawData)
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue