Merge pull request from GHSA-98j2-3j3p-fw2v

* fix: token injection vulnerability GHSA-98j2-3j3p-fw2v

- Ensure session IDs are securely generated server-side.
- Add validation to prevent user-supplied session IDs.
- Update tests to verify correct session token use.

This update addresses the critical session middleware vulnerability identified in versions 2 and above of GoFiber.

* chore: Remove unused code and dependencies in session store

* test(middleware/csrf): Save session after generating new session ID

This commit saves the session after generating a new session ID to ensure that the updated session ID is persisted. This change is necessary to address a critical session middleware vulnerability identified in versions 2 and above of GoFiber.

* chore: Save session ID in context for middleware chain

The code changes add functionality to save the newly generated session ID in the context, allowing it to be accessible to subsequent middlewares in the chain. This improvement ensures that the session ID is available for use throughout the middleware stack.

* refactor(session.go): general clean-up

* chore: Revert session freshness behavior

The code changes in `session_test.go` fix the session freshness check by updating the assertions for `sess.Fresh()` and `sess.ID()`. The previous assertions were incorrect and have been corrected to ensure the session ID remains the same and the session is not fresh.

* chore: Update session.Get method signature to use fiber.Ctx instead of *fiber.Ctx
This commit is contained in:
Jason McNeil 2024-06-26 04:17:41 -03:00 committed by GitHub
parent ba6ea675ba
commit b53802a5cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 146 additions and 76 deletions

View File

@ -88,6 +88,8 @@ func Test_CSRF_WithSession(t *testing.T) {
// the session string is no longer be 123
newSessionIDString := sess.ID()
sess.Save()
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config
@ -221,6 +223,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
// get session id
newSessionIDString := sess.ID()
sess.Save()
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config
@ -1089,6 +1093,8 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
// the session string is no longer be 123
newSessionIDString := sess.ID()
sess.Save()
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config

View File

@ -24,13 +24,23 @@ func Test_Session(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set session
ctx.Request().Header.SetCookie(store.sessionName, "123")
// get session
// Get a new session
sess, err := store.Get(ctx)
require.NoError(t, err)
require.True(t, sess.Fresh())
token := sess.ID()
sess.Save()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set session
ctx.Request().Header.SetCookie(store.sessionName, token)
// get session
sess, err = store.Get(ctx)
require.NoError(t, err)
require.False(t, sess.Fresh())
// get keys
keys := sess.Keys()
@ -63,12 +73,14 @@ func Test_Session(t *testing.T) {
// get id
id := sess.ID()
require.Equal(t, "123", id)
require.Equal(t, token, id)
// save the old session first
err = sess.Save()
require.NoError(t, err)
app.ReleaseCtx(ctx)
// requesting entirely new context to prevent falsy tests
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
@ -105,7 +117,6 @@ func Test_Session_Types(t *testing.T) {
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, "123")
@ -117,6 +128,10 @@ func Test_Session_Types(t *testing.T) {
// the session string is no longer be 123
newSessionIDString := sess.ID()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
type User struct {
@ -174,6 +189,11 @@ func Test_Session_Types(t *testing.T) {
err = sess.Save()
require.NoError(t, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
// get session
sess, err = store.Get(ctx)
require.NoError(t, err)
@ -259,6 +279,8 @@ func Test_Session_Types(t *testing.T) {
vcomplex128Result, ok := sess.Get("vcomplex128").(complex128)
require.True(t, ok)
require.Equal(t, vcomplex128, vcomplex128Result)
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Store_Reset
@ -270,7 +292,6 @@ func Test_Session_Store_Reset(t *testing.T) {
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
@ -284,6 +305,12 @@ func Test_Session_Store_Reset(t *testing.T) {
// reset store
require.NoError(t, store.Reset())
id := sess.ID()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, id)
// make sure the session is recreated
sess, err = store.Get(ctx)
@ -364,6 +391,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
// set value
sess.Set("name", "john")
token := sess.ID()
// expire this session in 5 seconds
sess.SetExpiry(sessionDuration)
@ -371,7 +400,11 @@ func Test_Session_Save_Expiration(t *testing.T) {
err = sess.Save()
require.NoError(t, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// here you need to get the old session yet
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
require.NoError(t, err)
require.Equal(t, "john", sess.Get("name"))
@ -379,10 +412,16 @@ func Test_Session_Save_Expiration(t *testing.T) {
// just to make sure the session has been expired
time.Sleep(sessionDuration + (10 * time.Millisecond))
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// here you should get a new session
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
require.NoError(t, err)
require.Nil(t, sess.Get("name"))
require.NotEqual(t, sess.ID(), token)
})
}
@ -428,7 +467,15 @@ func Test_Session_Destroy(t *testing.T) {
// set value & save
sess.Set("name", "fenny")
id := sess.ID()
require.NoError(t, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
ctx.Request().Header.Set(store.sessionName, id)
sess, err = store.Get(ctx)
require.NoError(t, err)
@ -472,7 +519,8 @@ func Test_Session_Cookie(t *testing.T) {
}
// go test -run Test_Session_Cookie_In_Response
func Test_Session_Cookie_In_Response(t *testing.T) {
// Regression: https://github.com/gofiber/fiber/pull/1191
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
t.Parallel()
store := New()
app := fiber.New()
@ -486,12 +534,14 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
require.NoError(t, err)
sess.Set("id", "1")
require.True(t, sess.Fresh())
id := sess.ID()
require.NoError(t, sess.Save())
sess, err = store.Get(ctx)
require.NoError(t, err)
sess.Set("name", "john")
require.True(t, sess.Fresh())
require.Equal(t, id, sess.ID()) // session id should be the same
require.Equal(t, "1", sess.Get("id"))
require.Equal(t, "john", sess.Get("name"))
@ -505,24 +555,32 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
sess, err := store.Get(ctx)
require.NoError(t, err)
ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
id := sess.ID()
sess.Set("id", "1")
require.NoError(t, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
require.NoError(t, err)
sess.Delete("id")
require.NoError(t, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
require.NoError(t, err)
require.False(t, sess.Fresh())
require.Nil(t, sess.Get("id"))
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Reset
@ -538,7 +596,6 @@ func Test_Session_Reset(t *testing.T) {
t.Parallel()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// a random session uuid
originalSessionUUIDString := ""
@ -555,6 +612,9 @@ func Test_Session_Reset(t *testing.T) {
err = freshSession.Save()
require.NoError(t, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
@ -587,6 +647,8 @@ func Test_Session_Reset(t *testing.T) {
// Check that the session id is not in the header or cookie anymore
require.Equal(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
require.Equal(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
app.ReleaseCtx(ctx)
})
}
@ -615,6 +677,12 @@ func Test_Session_Regenerate(t *testing.T) {
err = freshSession.Save()
require.NoError(t, err)
// release the context
app.ReleaseCtx(ctx)
// acquire a new context
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
@ -630,6 +698,9 @@ func Test_Session_Regenerate(t *testing.T) {
// acquiredSession.fresh should be true after regenerating
require.True(t, acquiredSession.Fresh())
// release the context
app.ReleaseCtx(ctx)
})
}

View File

@ -9,18 +9,27 @@ import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
// mux is a global mutex for session operations.
var mux sync.Mutex
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
const (
// sessionIDContextKey is the key used to store the session ID in the context locals.
sessionIDContextKey sessionIDKey = iota
)
type Store struct {
Config
}
var mux sync.Mutex
// New creates a new session store with the provided configuration.
func New(config ...Config) *Store {
// Set default config
cfg := configDefault(config...)
@ -34,31 +43,40 @@ func New(config ...Config) *Store {
}
}
// RegisterType will allow you to encode/decode custom types
// into any Storage provider
// RegisterType registers a custom type for encoding/decoding into any storage provider.
func (*Store) RegisterType(i any) {
gob.Register(i)
}
// Get will get/create a session
// Get retrieves or creates a session for the given context.
func (s *Store) Get(c fiber.Ctx) (*Session, error) {
var fresh bool
loadData := true
var rawData []byte
var err error
id := s.getSessionID(c)
id, ok := c.Locals(sessionIDContextKey).(string)
if !ok {
id = s.getSessionID(c)
}
if len(id) == 0 {
fresh = true
var err error
if id, err = s.responseCookies(c); err != nil {
fresh := ok // Assume the session is fresh if the ID is found in locals
// Attempt to fetch session data if an ID is provided
if id != "" {
rawData, err = s.Storage.Get(id)
if err != nil {
return nil, err
}
if rawData == nil {
// Data not found, prepare to generate a new session
id = ""
}
}
// If no key exist, create new one
if len(id) == 0 {
loadData = false
// Generate a new ID if needed
if id == "" {
fresh = true // The session is fresh if a new ID is generated
id = s.KeyGenerator()
c.Locals(sessionIDContextKey, id)
}
// Create session object
@ -68,36 +86,17 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) {
sess.id = id
sess.fresh = fresh
// Fetch existing data
if loadData {
raw, err := s.Storage.Get(id)
// Unmarshal if we found data
switch {
case err != nil:
return nil, err
case raw != nil:
mux.Lock()
defer mux.Unlock()
sess.byteBuffer.Write(raw)
encCache := gob.NewDecoder(sess.byteBuffer)
err := encCache.Decode(&sess.data.Data)
if err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
default:
// both raw and err is nil, which means id is not in the storage
sess.fresh = true
// Decode session data if found
if rawData != nil {
if err := sess.decodeSessionData(rawData); err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
}
return sess, nil
}
// getSessionID will return the session id from:
// 1. cookie
// 2. http headers
// 3. query string
// getSessionID returns the session ID from cookies, headers, or query string.
func (s *Store) getSessionID(c fiber.Ctx) string {
id := c.Cookies(s.sessionName)
if len(id) > 0 {
@ -121,35 +120,27 @@ func (s *Store) getSessionID(c fiber.Ctx) string {
return ""
}
func (s *Store) responseCookies(c fiber.Ctx) (string, error) {
// Get key from response cookie
cookieValue := c.Response().Header.PeekCookie(s.sessionName)
if len(cookieValue) == 0 {
return "", nil
}
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
err := cookie.ParseBytes(cookieValue)
if err != nil {
return "", err
}
value := make([]byte, len(cookie.Value()))
copy(value, cookie.Value())
id := string(value)
return id, nil
}
// Reset will delete all session from the storage
// Reset deletes all sessions from the storage.
func (s *Store) Reset() error {
return s.Storage.Reset()
}
// Delete deletes a session by its id.
// Delete deletes a session by its ID.
func (s *Store) Delete(id string) error {
if id == "" {
return ErrEmptySessionID
}
return s.Storage.Delete(id)
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

View File

@ -63,12 +63,14 @@ func Test_Store_getSessionID(t *testing.T) {
// go test -run Test_Store_Get
// Regression: https://github.com/gofiber/fiber/issues/1408
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
func Test_Store_Get(t *testing.T) {
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
t.Parallel()
unexpectedID := "test-session-id"
// fiber instance
app := fiber.New()
t.Run("session should persisted even session is invalid", func(t *testing.T) {
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
t.Parallel()
// session store
store := New()
@ -81,7 +83,7 @@ func Test_Store_Get(t *testing.T) {
acquiredSession, err := store.Get(ctx)
require.NoError(t, err)
require.Equal(t, unexpectedID, acquiredSession.ID())
require.NotEqual(t, unexpectedID, acquiredSession.ID())
})
}