mirror of https://github.com/gofiber/fiber.git
Merge pull request from GHSA-98j2-3j3p-fw2v
* fix: token injection vulnerability GHSA-98j2-3j3p-fw2v - Ensure session IDs are securely generated server-side. - Add validation to prevent user-supplied session IDs. - Update tests to verify correct session token use. This update addresses the critical session middleware vulnerability identified in versions 2 and above of GoFiber. * test(middleware/csrf): Save session after generating new session ID This commit saves the session after generating a new session ID to ensure that the updated session ID is persisted. This change is necessary to address a critical session middleware vulnerability identified in versions 2 and above of GoFiber. * chore: Save session ID in context for middleware chain The code changes add functionality to save the newly generated session ID in the context, allowing it to be accessible to subsequent middlewares in the chain. This improvement ensures that the session ID is available for use throughout the middleware stack. * test: Fix session freshness check in session_test The code changes in `session_test.go` fix the session freshness check by updating the assertions for `sess.Fresh()` and `sess.ID()`. The previous assertions were incorrect and have been corrected to ensure the session ID remains the same and the session is not fresh. * refactor(session.go): general clean-up * chore: Revert session freshness behavior The code changes in `session_test.go` fix the session freshness check by updating the assertions for `sess.Fresh()` and `sess.ID()`. The previous assertions were incorrect and have been corrected to ensure the session ID remains the same and the session is not fresh.pull/3043/head
parent
4262f5b591
commit
7926e5bf4d
|
@ -88,6 +88,8 @@ func Test_CSRF_WithSession(t *testing.T) {
|
|||
|
||||
// the session string is no longer be 123
|
||||
newSessionIDString := sess.ID()
|
||||
sess.Save()
|
||||
|
||||
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
|
||||
|
||||
// middleware config
|
||||
|
@ -221,6 +223,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
|
|||
|
||||
// get session id
|
||||
newSessionIDString := sess.ID()
|
||||
sess.Save()
|
||||
|
||||
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
|
||||
|
||||
// middleware config
|
||||
|
|
|
@ -25,13 +25,23 @@ func Test_Session(t *testing.T) {
|
|||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// set session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, "123")
|
||||
|
||||
// get session
|
||||
// Get a new session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
token := sess.ID()
|
||||
sess.Save()
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
|
||||
// get session
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
|
||||
// get keys
|
||||
keys := sess.Keys()
|
||||
|
@ -64,12 +74,14 @@ func Test_Session(t *testing.T) {
|
|||
|
||||
// get id
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, "123", id)
|
||||
utils.AssertEqual(t, token, id)
|
||||
|
||||
// save the old session first
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// requesting entirely new context to prevent falsy tests
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
@ -108,7 +120,6 @@ func Test_Session_Types(t *testing.T) {
|
|||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, "123")
|
||||
|
@ -120,6 +131,10 @@ func Test_Session_Types(t *testing.T) {
|
|||
|
||||
// the session string is no longer be 123
|
||||
newSessionIDString := sess.ID()
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
|
||||
|
||||
type User struct {
|
||||
|
@ -177,6 +192,11 @@ func Test_Session_Types(t *testing.T) {
|
|||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
|
||||
|
||||
// get session
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
@ -203,6 +223,7 @@ func Test_Session_Types(t *testing.T) {
|
|||
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
|
||||
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
|
||||
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Store_Reset
|
||||
|
@ -214,7 +235,6 @@ func Test_Session_Store_Reset(t *testing.T) {
|
|||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
|
@ -228,6 +248,12 @@ func Test_Session_Store_Reset(t *testing.T) {
|
|||
|
||||
// reset store
|
||||
utils.AssertEqual(t, nil, store.Reset())
|
||||
id := sess.ID()
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
// make sure the session is recreated
|
||||
sess, err = store.Get(ctx)
|
||||
|
@ -302,6 +328,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
|||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
token := sess.ID()
|
||||
|
||||
// expire this session in 5 seconds
|
||||
sess.SetExpiry(sessionDuration)
|
||||
|
||||
|
@ -309,7 +337,11 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
|||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// here you need to get the old session yet
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||
|
@ -317,10 +349,16 @@ func Test_Session_Save_Expiration(t *testing.T) {
|
|||
// just to make sure the session has been expired
|
||||
time.Sleep(sessionDuration + (10 * time.Millisecond))
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// here you should get a new session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, nil, sess.Get("name"))
|
||||
utils.AssertEqual(t, true, sess.ID() != token)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -364,7 +402,15 @@ func Test_Session_Destroy(t *testing.T) {
|
|||
|
||||
// set value & save
|
||||
sess.Set("name", "fenny")
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
ctx.Request().Header.Set(store.sessionName, id)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
|
@ -408,7 +454,8 @@ func Test_Session_Cookie(t *testing.T) {
|
|||
}
|
||||
|
||||
// go test -run Test_Session_Cookie_In_Response
|
||||
func Test_Session_Cookie_In_Response(t *testing.T) {
|
||||
// Regression: https://github.com/gofiber/fiber/pull/1191
|
||||
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
|
||||
t.Parallel()
|
||||
store := New()
|
||||
app := fiber.New()
|
||||
|
@ -421,6 +468,7 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
|
|||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("id", "1")
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
|
@ -428,8 +476,9 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
|
|||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("name", "john")
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, id, sess.ID()) // session id should be the same
|
||||
|
||||
utils.AssertEqual(t, "1", sess.Get("id"))
|
||||
utils.AssertEqual(t, sess.ID() != "1", true)
|
||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||
}
|
||||
|
||||
|
@ -441,24 +490,31 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
|
|||
app := fiber.New()
|
||||
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
|
||||
|
||||
id := sess.ID()
|
||||
sess.Set("id", "1")
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Delete("id")
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Get("id"))
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Reset
|
||||
|
@ -475,6 +531,9 @@ func Test_Session_Reset(t *testing.T) {
|
|||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
// a random session uuid
|
||||
originalSessionUUIDString := ""
|
||||
|
||||
|
@ -491,6 +550,9 @@ func Test_Session_Reset(t *testing.T) {
|
|||
err = freshSession.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
|
||||
|
||||
|
@ -524,6 +586,8 @@ func Test_Session_Reset(t *testing.T) {
|
|||
// Check that the session id is not in the header or cookie anymore
|
||||
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -551,6 +615,12 @@ func Test_Session_Regenerate(t *testing.T) {
|
|||
err = freshSession.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// release the context
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// acquire a new context
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
|
||||
|
||||
|
@ -566,6 +636,9 @@ func Test_Session_Regenerate(t *testing.T) {
|
|||
|
||||
// acquiredSession.fresh should be true after regenerating
|
||||
utils.AssertEqual(t, true, acquiredSession.Fresh())
|
||||
|
||||
// release the context
|
||||
app.ReleaseCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -9,19 +9,27 @@ import (
|
|||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ErrEmptySessionID is an error that occurs when the session ID is empty.
|
||||
var ErrEmptySessionID = errors.New("session id cannot be empty")
|
||||
|
||||
// mux is a global mutex for session operations.
|
||||
var mux sync.Mutex
|
||||
|
||||
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
|
||||
type sessionIDKey int
|
||||
|
||||
const (
|
||||
// sessionIDContextKey is the key used to store the session ID in the context locals.
|
||||
sessionIDContextKey sessionIDKey = iota
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
Config
|
||||
}
|
||||
|
||||
var mux sync.Mutex
|
||||
|
||||
// New creates a new session store with the provided configuration.
|
||||
func New(config ...Config) *Store {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
@ -35,31 +43,40 @@ func New(config ...Config) *Store {
|
|||
}
|
||||
}
|
||||
|
||||
// RegisterType will allow you to encode/decode custom types
|
||||
// into any Storage provider
|
||||
// RegisterType registers a custom type for encoding/decoding into any storage provider.
|
||||
func (*Store) RegisterType(i interface{}) {
|
||||
gob.Register(i)
|
||||
}
|
||||
|
||||
// Get will get/create a session
|
||||
// Get retrieves or creates a session for the given context.
|
||||
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
||||
var fresh bool
|
||||
loadData := true
|
||||
var rawData []byte
|
||||
var err error
|
||||
|
||||
id := s.getSessionID(c)
|
||||
id, ok := c.Locals(sessionIDContextKey).(string)
|
||||
if !ok {
|
||||
id = s.getSessionID(c)
|
||||
}
|
||||
|
||||
if len(id) == 0 {
|
||||
fresh = true
|
||||
var err error
|
||||
if id, err = s.responseCookies(c); err != nil {
|
||||
fresh := ok // Assume the session is fresh if the ID is found in locals
|
||||
|
||||
// Attempt to fetch session data if an ID is provided
|
||||
if id != "" {
|
||||
rawData, err = s.Storage.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rawData == nil {
|
||||
// Data not found, prepare to generate a new session
|
||||
id = ""
|
||||
}
|
||||
}
|
||||
|
||||
// If no key exist, create new one
|
||||
if len(id) == 0 {
|
||||
loadData = false
|
||||
// Generate a new ID if needed
|
||||
if id == "" {
|
||||
fresh = true // The session is fresh if a new ID is generated
|
||||
id = s.KeyGenerator()
|
||||
c.Locals(sessionIDContextKey, id)
|
||||
}
|
||||
|
||||
// Create session object
|
||||
|
@ -69,34 +86,17 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
|||
sess.id = id
|
||||
sess.fresh = fresh
|
||||
|
||||
// Fetch existing data
|
||||
if loadData {
|
||||
raw, err := s.Storage.Get(id)
|
||||
// Unmarshal if we found data
|
||||
if raw != nil && err == nil {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(sess.byteBuffer)
|
||||
err := encCache.Decode(&sess.data.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
// both raw and err is nil, which means id is not in the storage
|
||||
sess.fresh = true
|
||||
// Decode session data if found
|
||||
if rawData != nil {
|
||||
if err := sess.decodeSessionData(rawData); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// getSessionID will return the session id from:
|
||||
// 1. cookie
|
||||
// 2. http headers
|
||||
// 3. query string
|
||||
// getSessionID returns the session ID from cookies, headers, or query string.
|
||||
func (s *Store) getSessionID(c *fiber.Ctx) string {
|
||||
id := c.Cookies(s.sessionName)
|
||||
if len(id) > 0 {
|
||||
|
@ -120,35 +120,27 @@ func (s *Store) getSessionID(c *fiber.Ctx) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (s *Store) responseCookies(c *fiber.Ctx) (string, error) {
|
||||
// Get key from response cookie
|
||||
cookieValue := c.Response().Header.PeekCookie(s.sessionName)
|
||||
if len(cookieValue) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
err := cookie.ParseBytes(cookieValue)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
value := make([]byte, len(cookie.Value()))
|
||||
copy(value, cookie.Value())
|
||||
id := string(value)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Reset will delete all session from the storage
|
||||
// Reset deletes all sessions from the storage.
|
||||
func (s *Store) Reset() error {
|
||||
return s.Storage.Reset()
|
||||
}
|
||||
|
||||
// Delete deletes a session by its id.
|
||||
// Delete deletes a session by its ID.
|
||||
func (s *Store) Delete(id string) error {
|
||||
if id == "" {
|
||||
return ErrEmptySessionID
|
||||
}
|
||||
return s.Storage.Delete(id)
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -64,12 +64,13 @@ func TestStore_getSessionID(t *testing.T) {
|
|||
|
||||
// go test -run TestStore_Get
|
||||
// Regression: https://github.com/gofiber/fiber/issues/1408
|
||||
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
|
||||
func TestStore_Get(t *testing.T) {
|
||||
t.Parallel()
|
||||
unexpectedID := "test-session-id"
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
t.Run("session should persisted even session is invalid", func(t *testing.T) {
|
||||
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
|
@ -82,7 +83,7 @@ func TestStore_Get(t *testing.T) {
|
|||
acquiredSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
utils.AssertEqual(t, unexpectedID, acquiredSession.ID())
|
||||
utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue