mirror of https://github.com/gofiber/fiber.git
fix(middleware/session): mutex for thread safety (#3050)
* chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData function * chore: Refactor session benchmark tests * fix(middleware/session): mutex for thread safety * feat: Add session mutex lock for thread safety * chore: Refactor releaseSession mutexpull/3067/head
parent
6fa0e7c9fc
commit
66a881441b
|
@ -14,6 +14,7 @@ import (
|
|||
)
|
||||
|
||||
type Session struct {
|
||||
mu sync.RWMutex // Mutex to protect non-data fields
|
||||
id string // session id
|
||||
fresh bool // if new session
|
||||
ctx *fiber.Ctx // fiber context
|
||||
|
@ -42,6 +43,7 @@ func acquireSession() *Session {
|
|||
}
|
||||
|
||||
func releaseSession(s *Session) {
|
||||
s.mu.Lock()
|
||||
s.id = ""
|
||||
s.exp = 0
|
||||
s.ctx = nil
|
||||
|
@ -52,16 +54,21 @@ func releaseSession(s *Session) {
|
|||
if s.byteBuffer != nil {
|
||||
s.byteBuffer.Reset()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
// Fresh is true if the current session is new
|
||||
func (s *Session) Fresh() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fresh
|
||||
}
|
||||
|
||||
// ID returns the session id
|
||||
func (s *Session) ID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.id
|
||||
}
|
||||
|
||||
|
@ -102,6 +109,9 @@ func (s *Session) Destroy() error {
|
|||
// Reset local data
|
||||
s.data.Reset()
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Use external Storage if exist
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
|
@ -114,6 +124,9 @@ func (s *Session) Destroy() error {
|
|||
|
||||
// Regenerate generates a new session id and delete the old one from Storage
|
||||
func (s *Session) Regenerate() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Delete old id from storage
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
|
@ -131,6 +144,10 @@ func (s *Session) Reset() error {
|
|||
if s.data != nil {
|
||||
s.data.Reset()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Reset byte buffer
|
||||
if s.byteBuffer != nil {
|
||||
s.byteBuffer.Reset()
|
||||
|
@ -154,20 +171,24 @@ func (s *Session) Reset() error {
|
|||
|
||||
// refresh generates a new session, and set session.fresh to be true
|
||||
func (s *Session) refresh() {
|
||||
// Create a new id
|
||||
s.id = s.config.KeyGenerator()
|
||||
|
||||
// We assign a new id to the session, so the session must be fresh
|
||||
s.fresh = true
|
||||
}
|
||||
|
||||
// Save will update the storage and client cookie
|
||||
//
|
||||
// sess.Save() will save the session data to the storage and update the
|
||||
// client cookie, and it will release the session after saving.
|
||||
//
|
||||
// It's not safe to use the session after calling Save().
|
||||
func (s *Session) Save() error {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if session has your own expiration, otherwise use default value
|
||||
if s.exp <= 0 {
|
||||
s.exp = s.config.Expiration
|
||||
|
@ -177,25 +198,25 @@ func (s *Session) Save() error {
|
|||
s.setSession()
|
||||
|
||||
// Convert data to bytes
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
encCache := gob.NewEncoder(s.byteBuffer)
|
||||
err := encCache.Encode(&s.data.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode data: %w", err)
|
||||
}
|
||||
|
||||
// copy the data in buffer
|
||||
// Copy the data in buffer
|
||||
encodedBytes := make([]byte, s.byteBuffer.Len())
|
||||
copy(encodedBytes, s.byteBuffer.Bytes())
|
||||
|
||||
// pass copied bytes with session id to provider
|
||||
// Pass copied bytes with session id to provider
|
||||
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
|
||||
// Release session
|
||||
// TODO: It's not safe to use the Session after called Save()
|
||||
// TODO: It's not safe to use the Session after calling Save()
|
||||
releaseSession(s)
|
||||
|
||||
return nil
|
||||
|
@ -211,6 +232,8 @@ func (s *Session) Keys() []string {
|
|||
|
||||
// SetExpiry sets a specific expiration for this session
|
||||
func (s *Session) SetExpiry(exp time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.exp = exp
|
||||
}
|
||||
|
||||
|
@ -276,3 +299,13 @@ func (s *Session) delSession() {
|
|||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) {
|
|||
utils.AssertEqual(b, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
|
||||
func Benchmark_Session_Parallel(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
|
||||
func Benchmark_Session_Asserted(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(b, nil, err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(b, nil, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
|
||||
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
utils.AssertEqual(b, nil, sess.Save())
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
utils.AssertEqual(b, nil, sess.Save())
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -race -run Test_Session_Concurrency ./...
|
||||
func Test_Session_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
store := New()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 10) // Buffered channel to collect errors
|
||||
const numGoroutines = 10 // Number of concurrent goroutines to test
|
||||
|
||||
// Start numGoroutines goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
sess, err := store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Set a value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// get the session id
|
||||
id := sess.ID()
|
||||
|
||||
// Check if the session is fresh
|
||||
if !sess.Fresh() {
|
||||
errChan <- errors.New("session should be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the session
|
||||
if err := sess.Save(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Release the context
|
||||
app.ReleaseCtx(localCtx)
|
||||
|
||||
// Acquire a new context
|
||||
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(localCtx)
|
||||
|
||||
// Set the session id in the header
|
||||
localCtx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
// Get the session
|
||||
sess, err = store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the value
|
||||
name := sess.Get("name")
|
||||
if name != "john" {
|
||||
errChan <- errors.New("name should be john")
|
||||
return
|
||||
}
|
||||
|
||||
// Get ID from the session
|
||||
if sess.ID() != id {
|
||||
errChan <- errors.New("id should be the same")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the session is fresh
|
||||
if sess.Fresh() {
|
||||
errChan <- errors.New("session should not be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the key
|
||||
sess.Delete("name")
|
||||
|
||||
// Get the value
|
||||
name = sess.Get("name")
|
||||
if name != nil {
|
||||
errChan <- errors.New("name should be nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
if err := sess.Destroy(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait() // Wait for all goroutines to finish
|
||||
close(errChan) // Close the channel to signal no more errors will be sent
|
||||
|
||||
// Check for errors sent to errChan
|
||||
for err := range errChan {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
|
@ -14,9 +13,6 @@ import (
|
|||
// ErrEmptySessionID is an error that occurs when the session ID is empty.
|
||||
var ErrEmptySessionID = errors.New("session id cannot be empty")
|
||||
|
||||
// mux is a global mutex for session operations.
|
||||
var mux sync.Mutex
|
||||
|
||||
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
|
||||
type sessionIDKey int
|
||||
|
||||
|
@ -81,6 +77,10 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
|||
|
||||
// Create session object
|
||||
sess := acquireSession()
|
||||
|
||||
sess.mu.Lock()
|
||||
defer sess.mu.Unlock()
|
||||
|
||||
sess.ctx = c
|
||||
sess.config = s
|
||||
sess.id = id
|
||||
|
@ -88,6 +88,8 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
|||
|
||||
// Decode session data if found
|
||||
if rawData != nil {
|
||||
sess.data.Lock()
|
||||
defer sess.data.Unlock()
|
||||
if err := sess.decodeSessionData(rawData); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
|
@ -132,15 +134,3 @@ func (s *Store) Delete(id string) error {
|
|||
}
|
||||
return s.Storage.Delete(id)
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue