fix(middleware/session): mutex for thread safety (#3050)

* chore: Remove extra release and acquire ctx calls in session_test.go

* feat: Remove unnecessary session mutex lock in decodeSessionData function

* chore: Refactor session benchmark tests

* fix(middleware/session): mutex for thread safety

* feat: Add session mutex lock for thread safety

* chore: Refactor releaseSession mutex
pull/3067/head
Jason McNeil 2024-06-30 16:16:23 -03:00 committed by GitHub
parent 6fa0e7c9fc
commit 66a881441b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 276 additions and 24 deletions

View File

@ -14,6 +14,7 @@ import (
)
type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx *fiber.Ctx // fiber context
@ -42,6 +43,7 @@ func acquireSession() *Session {
}
func releaseSession(s *Session) {
s.mu.Lock()
s.id = ""
s.exp = 0
s.ctx = nil
@ -52,16 +54,21 @@ func releaseSession(s *Session) {
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
s.mu.Unlock()
sessionPool.Put(s)
}
// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}
// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}
@ -102,6 +109,9 @@ func (s *Session) Destroy() error {
// Reset local data
s.data.Reset()
s.mu.RLock()
defer s.mu.RUnlock()
// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
@ -114,6 +124,9 @@ func (s *Session) Destroy() error {
// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
@ -131,6 +144,10 @@ func (s *Session) Reset() error {
if s.data != nil {
s.data.Reset()
}
s.mu.Lock()
defer s.mu.Unlock()
// Reset byte buffer
if s.byteBuffer != nil {
s.byteBuffer.Reset()
@ -154,20 +171,24 @@ func (s *Session) Reset() error {
// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
// Create a new id
s.id = s.config.KeyGenerator()
// We assign a new id to the session, so the session must be fresh
s.fresh = true
}
// Save will update the storage and client cookie
//
// sess.Save() will save the session data to the storage and update the
// client cookie, and it will release the session after saving.
//
// It's not safe to use the session after calling Save().
func (s *Session) Save() error {
// Better safe than sorry
if s.data == nil {
return nil
}
s.mu.Lock()
// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
@ -177,25 +198,25 @@ func (s *Session) Save() error {
s.setSession()
// Convert data to bytes
mux.Lock()
defer mux.Unlock()
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}
// copy the data in buffer
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())
// pass copied bytes with session id to provider
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}
s.mu.Unlock()
// Release session
// TODO: It's not safe to use the Session after called Save()
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)
return nil
@ -211,6 +232,8 @@ func (s *Session) Keys() []string {
// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}
@ -276,3 +299,13 @@ func (s *Session) delSession() {
fasthttp.ReleaseCookie(fcookie)
}
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

View File

@ -1,6 +1,8 @@
package session
import (
"errors"
"sync"
"testing"
"time"
@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) {
utils.AssertEqual(b, nil, err)
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
func Benchmark_Session_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
func Benchmark_Session_Asserted(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})
}
// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()
var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test
// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Set a value
sess.Set("name", "john")
// get the session id
id := sess.ID()
// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}
// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}
// Release the context
app.ReleaseCtx(localCtx)
// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)
// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)
// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}
// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}
// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}
// Delete the key
sess.Delete("name")
// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}
// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}
wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent
// Check for errors sent to errChan
for err := range errChan {
utils.AssertEqual(t, nil, err)
}
}

View File

@ -4,7 +4,6 @@ import (
"encoding/gob"
"errors"
"fmt"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
@ -14,9 +13,6 @@ import (
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
// mux is a global mutex for session operations.
var mux sync.Mutex
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
@ -81,6 +77,10 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
// Create session object
sess := acquireSession()
sess.mu.Lock()
defer sess.mu.Unlock()
sess.ctx = c
sess.config = s
sess.id = id
@ -88,6 +88,8 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
// Decode session data if found
if rawData != nil {
sess.data.Lock()
defer sess.data.Unlock()
if err := sess.decodeSessionData(rawData); err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
@ -132,15 +134,3 @@ func (s *Store) Delete(id string) error {
}
return s.Storage.Delete(id)
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}