Merge pull request from GHSA-98j2-3j3p-fw2v

* fix: token injection vulnerability GHSA-98j2-3j3p-fw2v

- Ensure session IDs are securely generated server-side.
- Add validation to prevent user-supplied session IDs.
- Update tests to verify correct session token use.

This update addresses the critical session middleware vulnerability identified in versions 2 and above of GoFiber.

* test(middleware/csrf): Save session after generating new session ID

This commit saves the session after generating a new session ID to ensure that the updated session ID is persisted. This change is necessary to address a critical session middleware vulnerability identified in versions 2 and above of GoFiber.

* chore: Save session ID in context for middleware chain

The code changes add functionality to save the newly generated session ID in the context, allowing it to be accessible to subsequent middlewares in the chain. This improvement ensures that the session ID is available for use throughout the middleware stack.

* test: Fix session freshness check in session_test

The code changes in `session_test.go` fix the session freshness check by updating the assertions for `sess.Fresh()` and `sess.ID()`. The previous assertions were incorrect and have been corrected to ensure the session ID remains the same and the session is not fresh.

* refactor(session.go): general clean-up

* chore: Revert session freshness behavior

The code changes in `session_test.go` fix the session freshness check by updating the assertions for `sess.Fresh()` and `sess.ID()`. The previous assertions were incorrect and have been corrected to ensure the session ID remains the same and the session is not fresh.
pull/3043/head
Jason McNeil 2024-06-26 04:17:41 -03:00 committed by GitHub
parent 4262f5b591
commit 7926e5bf4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 145 additions and 75 deletions

View File

@ -88,6 +88,8 @@ func Test_CSRF_WithSession(t *testing.T) {
// the session string is no longer be 123
newSessionIDString := sess.ID()
sess.Save()
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config
@ -221,6 +223,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
// get session id
newSessionIDString := sess.ID()
sess.Save()
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config

View File

@ -25,13 +25,23 @@ func Test_Session(t *testing.T) {
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set session
ctx.Request().Header.SetCookie(store.sessionName, "123")
// get session
// Get a new session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
token := sess.ID()
sess.Save()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set session
ctx.Request().Header.SetCookie(store.sessionName, token)
// get session
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
// get keys
keys := sess.Keys()
@ -64,12 +74,14 @@ func Test_Session(t *testing.T) {
// get id
id := sess.ID()
utils.AssertEqual(t, "123", id)
utils.AssertEqual(t, token, id)
// save the old session first
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
// requesting entirely new context to prevent falsy tests
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
@ -108,7 +120,6 @@ func Test_Session_Types(t *testing.T) {
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, "123")
@ -120,6 +131,10 @@ func Test_Session_Types(t *testing.T) {
// the session string is no longer be 123
newSessionIDString := sess.ID()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
type User struct {
@ -177,6 +192,11 @@ func Test_Session_Types(t *testing.T) {
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
// get session
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
@ -203,6 +223,7 @@ func Test_Session_Types(t *testing.T) {
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Store_Reset
@ -214,7 +235,6 @@ func Test_Session_Store_Reset(t *testing.T) {
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
@ -228,6 +248,12 @@ func Test_Session_Store_Reset(t *testing.T) {
// reset store
utils.AssertEqual(t, nil, store.Reset())
id := sess.ID()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, id)
// make sure the session is recreated
sess, err = store.Get(ctx)
@ -302,6 +328,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
// set value
sess.Set("name", "john")
token := sess.ID()
// expire this session in 5 seconds
sess.SetExpiry(sessionDuration)
@ -309,7 +337,11 @@ func Test_Session_Save_Expiration(t *testing.T) {
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// here you need to get the old session yet
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "john", sess.Get("name"))
@ -317,10 +349,16 @@ func Test_Session_Save_Expiration(t *testing.T) {
// just to make sure the session has been expired
time.Sleep(sessionDuration + (10 * time.Millisecond))
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// here you should get a new session
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Get("name"))
utils.AssertEqual(t, true, sess.ID() != token)
})
}
@ -364,7 +402,15 @@ func Test_Session_Destroy(t *testing.T) {
// set value & save
sess.Set("name", "fenny")
id := sess.ID()
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
ctx.Request().Header.Set(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
@ -408,7 +454,8 @@ func Test_Session_Cookie(t *testing.T) {
}
// go test -run Test_Session_Cookie_In_Response
func Test_Session_Cookie_In_Response(t *testing.T) {
// Regression: https://github.com/gofiber/fiber/pull/1191
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
t.Parallel()
store := New()
app := fiber.New()
@ -421,6 +468,7 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("id", "1")
id := sess.ID()
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Save())
@ -428,8 +476,9 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
utils.AssertEqual(t, nil, err)
sess.Set("name", "john")
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, id, sess.ID()) // session id should be the same
utils.AssertEqual(t, "1", sess.Get("id"))
utils.AssertEqual(t, sess.ID() != "1", true)
utils.AssertEqual(t, "john", sess.Get("name"))
}
@ -441,24 +490,31 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
id := sess.ID()
sess.Set("id", "1")
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Delete("id")
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
utils.AssertEqual(t, nil, sess.Get("id"))
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Reset
@ -475,6 +531,9 @@ func Test_Session_Reset(t *testing.T) {
defer app.ReleaseCtx(ctx)
t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
t.Parallel()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// a random session uuid
originalSessionUUIDString := ""
@ -491,6 +550,9 @@ func Test_Session_Reset(t *testing.T) {
err = freshSession.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
@ -524,6 +586,8 @@ func Test_Session_Reset(t *testing.T) {
// Check that the session id is not in the header or cookie anymore
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
app.ReleaseCtx(ctx)
})
}
@ -551,6 +615,12 @@ func Test_Session_Regenerate(t *testing.T) {
err = freshSession.Save()
utils.AssertEqual(t, nil, err)
// release the context
app.ReleaseCtx(ctx)
// acquire a new context
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
@ -566,6 +636,9 @@ func Test_Session_Regenerate(t *testing.T) {
// acquiredSession.fresh should be true after regenerating
utils.AssertEqual(t, true, acquiredSession.Fresh())
// release the context
app.ReleaseCtx(ctx)
})
}

View File

@ -9,19 +9,27 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
// mux is a global mutex for session operations.
var mux sync.Mutex
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
const (
// sessionIDContextKey is the key used to store the session ID in the context locals.
sessionIDContextKey sessionIDKey = iota
)
type Store struct {
Config
}
var mux sync.Mutex
// New creates a new session store with the provided configuration.
func New(config ...Config) *Store {
// Set default config
cfg := configDefault(config...)
@ -35,31 +43,40 @@ func New(config ...Config) *Store {
}
}
// RegisterType will allow you to encode/decode custom types
// into any Storage provider
// RegisterType registers a custom type for encoding/decoding into any storage provider.
func (*Store) RegisterType(i interface{}) {
gob.Register(i)
}
// Get will get/create a session
// Get retrieves or creates a session for the given context.
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
var fresh bool
loadData := true
var rawData []byte
var err error
id := s.getSessionID(c)
id, ok := c.Locals(sessionIDContextKey).(string)
if !ok {
id = s.getSessionID(c)
}
if len(id) == 0 {
fresh = true
var err error
if id, err = s.responseCookies(c); err != nil {
fresh := ok // Assume the session is fresh if the ID is found in locals
// Attempt to fetch session data if an ID is provided
if id != "" {
rawData, err = s.Storage.Get(id)
if err != nil {
return nil, err
}
if rawData == nil {
// Data not found, prepare to generate a new session
id = ""
}
}
// If no key exist, create new one
if len(id) == 0 {
loadData = false
// Generate a new ID if needed
if id == "" {
fresh = true // The session is fresh if a new ID is generated
id = s.KeyGenerator()
c.Locals(sessionIDContextKey, id)
}
// Create session object
@ -69,34 +86,17 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
sess.id = id
sess.fresh = fresh
// Fetch existing data
if loadData {
raw, err := s.Storage.Get(id)
// Unmarshal if we found data
if raw != nil && err == nil {
mux.Lock()
defer mux.Unlock()
_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(sess.byteBuffer)
err := encCache.Decode(&sess.data.Data)
if err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
} else if err != nil {
return nil, err
} else {
// both raw and err is nil, which means id is not in the storage
sess.fresh = true
// Decode session data if found
if rawData != nil {
if err := sess.decodeSessionData(rawData); err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
}
return sess, nil
}
// getSessionID will return the session id from:
// 1. cookie
// 2. http headers
// 3. query string
// getSessionID returns the session ID from cookies, headers, or query string.
func (s *Store) getSessionID(c *fiber.Ctx) string {
id := c.Cookies(s.sessionName)
if len(id) > 0 {
@ -120,35 +120,27 @@ func (s *Store) getSessionID(c *fiber.Ctx) string {
return ""
}
func (s *Store) responseCookies(c *fiber.Ctx) (string, error) {
// Get key from response cookie
cookieValue := c.Response().Header.PeekCookie(s.sessionName)
if len(cookieValue) == 0 {
return "", nil
}
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
err := cookie.ParseBytes(cookieValue)
if err != nil {
return "", err
}
value := make([]byte, len(cookie.Value()))
copy(value, cookie.Value())
id := string(value)
return id, nil
}
// Reset will delete all session from the storage
// Reset deletes all sessions from the storage.
func (s *Store) Reset() error {
return s.Storage.Reset()
}
// Delete deletes a session by its id.
// Delete deletes a session by its ID.
func (s *Store) Delete(id string) error {
if id == "" {
return ErrEmptySessionID
}
return s.Storage.Delete(id)
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

View File

@ -64,12 +64,13 @@ func TestStore_getSessionID(t *testing.T) {
// go test -run TestStore_Get
// Regression: https://github.com/gofiber/fiber/issues/1408
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
func TestStore_Get(t *testing.T) {
t.Parallel()
unexpectedID := "test-session-id"
// fiber instance
app := fiber.New()
t.Run("session should persisted even session is invalid", func(t *testing.T) {
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
t.Parallel()
// session store
store := New()
@ -82,7 +83,7 @@ func TestStore_Get(t *testing.T) {
acquiredSession, err := store.Get(ctx)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, unexpectedID, acquiredSession.ID())
utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true)
})
}