diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 60f93abe..a5f08d25 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -88,6 +88,8 @@ func Test_CSRF_WithSession(t *testing.T) { // the session string is no longer be 123 newSessionIDString := sess.ID() + sess.Save() + app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString) // middleware config @@ -221,6 +223,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { // get session id newSessionIDString := sess.ID() + sess.Save() + app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString) // middleware config diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 788a3c82..9b185c02 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -25,13 +25,23 @@ func Test_Session(t *testing.T) { ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) - // set session - ctx.Request().Header.SetCookie(store.sessionName, "123") - - // get session + // Get a new session sess, err := store.Get(ctx) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, sess.Fresh()) + token := sess.ID() + sess.Save() + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + + // set session + ctx.Request().Header.SetCookie(store.sessionName, token) + + // get session + sess, err = store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, false, sess.Fresh()) // get keys keys := sess.Keys() @@ -64,12 +74,14 @@ func Test_Session(t *testing.T) { // get id id := sess.ID() - utils.AssertEqual(t, "123", id) + utils.AssertEqual(t, token, id) // save the old session first err = sess.Save() utils.AssertEqual(t, nil, err) + app.ReleaseCtx(ctx) + // requesting entirely new context to prevent falsy tests ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -108,7 +120,6 @@ func Test_Session_Types(t *testing.T) { // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.sessionName, "123") @@ -120,6 +131,10 @@ func Test_Session_Types(t *testing.T) { // the session string is no longer be 123 newSessionIDString := sess.ID() + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString) type User struct { @@ -177,6 +192,11 @@ func Test_Session_Types(t *testing.T) { err = sess.Save() utils.AssertEqual(t, nil, err) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + + ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString) + // get session sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) @@ -203,6 +223,7 @@ func Test_Session_Types(t *testing.T) { utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64)) utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64)) utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128)) + app.ReleaseCtx(ctx) } // go test -run Test_Session_Store_Reset @@ -214,7 +235,6 @@ func Test_Session_Store_Reset(t *testing.T) { app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) @@ -228,6 +248,12 @@ func Test_Session_Store_Reset(t *testing.T) { // reset store utils.AssertEqual(t, nil, store.Reset()) + id := sess.ID() + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie(store.sessionName, id) // make sure the session is recreated sess, err = store.Get(ctx) @@ -302,6 +328,8 @@ func Test_Session_Save_Expiration(t *testing.T) { // set value sess.Set("name", "john") + token := sess.ID() + // expire this session in 5 seconds sess.SetExpiry(sessionDuration) @@ -309,7 +337,11 @@ func Test_Session_Save_Expiration(t *testing.T) { err = sess.Save() utils.AssertEqual(t, nil, err) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + // here you need to get the old session yet + ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "john", sess.Get("name")) @@ -317,10 +349,16 @@ func Test_Session_Save_Expiration(t *testing.T) { // just to make sure the session has been expired time.Sleep(sessionDuration + (10 * time.Millisecond)) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + // here you should get a new session + ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, sess.Get("name")) + utils.AssertEqual(t, true, sess.ID() != token) }) } @@ -364,7 +402,15 @@ func Test_Session_Destroy(t *testing.T) { // set value & save sess.Set("name", "fenny") + id := sess.ID() utils.AssertEqual(t, nil, sess.Save()) + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get session + ctx.Request().Header.Set(store.sessionName, id) sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) @@ -408,7 +454,8 @@ func Test_Session_Cookie(t *testing.T) { } // go test -run Test_Session_Cookie_In_Response -func Test_Session_Cookie_In_Response(t *testing.T) { +// Regression: https://github.com/gofiber/fiber/pull/1191 +func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { t.Parallel() store := New() app := fiber.New() @@ -421,6 +468,7 @@ func Test_Session_Cookie_In_Response(t *testing.T) { sess, err := store.Get(ctx) utils.AssertEqual(t, nil, err) sess.Set("id", "1") + id := sess.ID() utils.AssertEqual(t, true, sess.Fresh()) utils.AssertEqual(t, nil, sess.Save()) @@ -428,8 +476,9 @@ func Test_Session_Cookie_In_Response(t *testing.T) { utils.AssertEqual(t, nil, err) sess.Set("name", "john") utils.AssertEqual(t, true, sess.Fresh()) + utils.AssertEqual(t, id, sess.ID()) // session id should be the same - utils.AssertEqual(t, "1", sess.Get("id")) + utils.AssertEqual(t, sess.ID() != "1", true) utils.AssertEqual(t, "john", sess.Get("name")) } @@ -441,24 +490,31 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) sess, err := store.Get(ctx) utils.AssertEqual(t, nil, err) - ctx.Request().Header.SetCookie(store.sessionName, sess.ID()) - + id := sess.ID() sess.Set("id", "1") utils.AssertEqual(t, nil, sess.Save()) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().Header.SetCookie(store.sessionName, id) + sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) sess.Delete("id") utils.AssertEqual(t, nil, sess.Save()) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().Header.SetCookie(store.sessionName, id) + sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, false, sess.Fresh()) utils.AssertEqual(t, nil, sess.Get("id")) + app.ReleaseCtx(ctx) } // go test -run Test_Session_Reset @@ -475,6 +531,9 @@ func Test_Session_Reset(t *testing.T) { defer app.ReleaseCtx(ctx) t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { + t.Parallel() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // a random session uuid originalSessionUUIDString := "" @@ -491,6 +550,9 @@ func Test_Session_Reset(t *testing.T) { err = freshSession.Save() utils.AssertEqual(t, nil, err) + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + // set cookie ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString) @@ -524,6 +586,8 @@ func Test_Session_Reset(t *testing.T) { // Check that the session id is not in the header or cookie anymore utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName))) utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName))) + + app.ReleaseCtx(ctx) }) } @@ -551,6 +615,12 @@ func Test_Session_Regenerate(t *testing.T) { err = freshSession.Save() utils.AssertEqual(t, nil, err) + // release the context + app.ReleaseCtx(ctx) + + // acquire a new context + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + // set cookie ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString) @@ -566,6 +636,9 @@ func Test_Session_Regenerate(t *testing.T) { // acquiredSession.fresh should be true after regenerating utils.AssertEqual(t, true, acquiredSession.Fresh()) + + // release the context + app.ReleaseCtx(ctx) }) } diff --git a/middleware/session/store.go b/middleware/session/store.go index f0c84f06..3b692f49 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -9,19 +9,27 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/internal/storage/memory" "github.com/gofiber/fiber/v2/utils" - - "github.com/valyala/fasthttp" ) // ErrEmptySessionID is an error that occurs when the session ID is empty. var ErrEmptySessionID = errors.New("session id cannot be empty") +// mux is a global mutex for session operations. +var mux sync.Mutex + +// sessionIDKey is the local key type used to store and retrieve the session ID in context. +type sessionIDKey int + +const ( + // sessionIDContextKey is the key used to store the session ID in the context locals. + sessionIDContextKey sessionIDKey = iota +) + type Store struct { Config } -var mux sync.Mutex - +// New creates a new session store with the provided configuration. func New(config ...Config) *Store { // Set default config cfg := configDefault(config...) @@ -35,31 +43,40 @@ func New(config ...Config) *Store { } } -// RegisterType will allow you to encode/decode custom types -// into any Storage provider +// RegisterType registers a custom type for encoding/decoding into any storage provider. func (*Store) RegisterType(i interface{}) { gob.Register(i) } -// Get will get/create a session +// Get retrieves or creates a session for the given context. func (s *Store) Get(c *fiber.Ctx) (*Session, error) { - var fresh bool - loadData := true + var rawData []byte + var err error - id := s.getSessionID(c) + id, ok := c.Locals(sessionIDContextKey).(string) + if !ok { + id = s.getSessionID(c) + } - if len(id) == 0 { - fresh = true - var err error - if id, err = s.responseCookies(c); err != nil { + fresh := ok // Assume the session is fresh if the ID is found in locals + + // Attempt to fetch session data if an ID is provided + if id != "" { + rawData, err = s.Storage.Get(id) + if err != nil { return nil, err } + if rawData == nil { + // Data not found, prepare to generate a new session + id = "" + } } - // If no key exist, create new one - if len(id) == 0 { - loadData = false + // Generate a new ID if needed + if id == "" { + fresh = true // The session is fresh if a new ID is generated id = s.KeyGenerator() + c.Locals(sessionIDContextKey, id) } // Create session object @@ -69,34 +86,17 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) { sess.id = id sess.fresh = fresh - // Fetch existing data - if loadData { - raw, err := s.Storage.Get(id) - // Unmarshal if we found data - if raw != nil && err == nil { - mux.Lock() - defer mux.Unlock() - _, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail - encCache := gob.NewDecoder(sess.byteBuffer) - err := encCache.Decode(&sess.data.Data) - if err != nil { - return nil, fmt.Errorf("failed to decode session data: %w", err) - } - } else if err != nil { - return nil, err - } else { - // both raw and err is nil, which means id is not in the storage - sess.fresh = true + // Decode session data if found + if rawData != nil { + if err := sess.decodeSessionData(rawData); err != nil { + return nil, fmt.Errorf("failed to decode session data: %w", err) } } return sess, nil } -// getSessionID will return the session id from: -// 1. cookie -// 2. http headers -// 3. query string +// getSessionID returns the session ID from cookies, headers, or query string. func (s *Store) getSessionID(c *fiber.Ctx) string { id := c.Cookies(s.sessionName) if len(id) > 0 { @@ -120,35 +120,27 @@ func (s *Store) getSessionID(c *fiber.Ctx) string { return "" } -func (s *Store) responseCookies(c *fiber.Ctx) (string, error) { - // Get key from response cookie - cookieValue := c.Response().Header.PeekCookie(s.sessionName) - if len(cookieValue) == 0 { - return "", nil - } - - cookie := fasthttp.AcquireCookie() - defer fasthttp.ReleaseCookie(cookie) - err := cookie.ParseBytes(cookieValue) - if err != nil { - return "", err - } - - value := make([]byte, len(cookie.Value())) - copy(value, cookie.Value()) - id := string(value) - return id, nil -} - -// Reset will delete all session from the storage +// Reset deletes all sessions from the storage. func (s *Store) Reset() error { return s.Storage.Reset() } -// Delete deletes a session by its id. +// Delete deletes a session by its ID. func (s *Store) Delete(id string) error { if id == "" { return ErrEmptySessionID } return s.Storage.Delete(id) } + +// decodeSessionData decodes the session data from raw bytes. +func (s *Session) decodeSessionData(rawData []byte) error { + mux.Lock() + defer mux.Unlock() + _, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail + encCache := gob.NewDecoder(s.byteBuffer) + if err := encCache.Decode(&s.data.Data); err != nil { + return fmt.Errorf("failed to decode session data: %w", err) + } + return nil +} diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index fbc6e846..d99868b9 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -64,12 +64,13 @@ func TestStore_getSessionID(t *testing.T) { // go test -run TestStore_Get // Regression: https://github.com/gofiber/fiber/issues/1408 +// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v func TestStore_Get(t *testing.T) { t.Parallel() unexpectedID := "test-session-id" // fiber instance app := fiber.New() - t.Run("session should persisted even session is invalid", func(t *testing.T) { + t.Run("session should be re-generated if it is invalid", func(t *testing.T) { t.Parallel() // session store store := New() @@ -82,7 +83,7 @@ func TestStore_Get(t *testing.T) { acquiredSession, err := store.Get(ctx) utils.AssertEqual(t, err, nil) - utils.AssertEqual(t, unexpectedID, acquiredSession.ID()) + utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true) }) }