mirror of https://github.com/gofiber/fiber.git
514 lines
12 KiB
Go
514 lines
12 KiB
Go
package session
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/gob"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/utils/v2"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// Session represents a user session.
|
|
type Session struct {
|
|
ctx fiber.Ctx // fiber context
|
|
config *Store // store configuration
|
|
data *data // key value data
|
|
id string // session id
|
|
idleTimeout time.Duration // idleTimeout of this session
|
|
mu sync.RWMutex // Mutex to protect non-data fields
|
|
fresh bool // if new session
|
|
}
|
|
|
|
// absExpirationKeyType is a distinct named type for the session data key that
// holds the absolute expiration time.
type absExpirationKeyType int

// absExpirationKey is the session data key under which the absolute
// expiration time of the session is stored.
const absExpirationKey absExpirationKeyType = iota
|
|
|
|
// byteBufferPool recycles *bytes.Buffer values that serve as scratch space
// while session data is encoded and decoded.
var byteBufferPool = sync.Pool{
	New: func() any {
		return &bytes.Buffer{}
	},
}
|
|
|
|
var sessionPool = sync.Pool{
|
|
New: func() any {
|
|
return &Session{}
|
|
},
|
|
}
|
|
|
|
// acquireSession returns a new Session from the pool.
|
|
//
|
|
// Returns:
|
|
// - *Session: The session object.
|
|
//
|
|
// Usage:
|
|
//
|
|
// s := acquireSession()
|
|
func acquireSession() *Session {
|
|
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
|
if s.data == nil {
|
|
s.data = acquireData()
|
|
}
|
|
s.fresh = true
|
|
return s
|
|
}
|
|
|
|
// Release releases the session back to the pool.
|
|
//
|
|
// This function should be called after the session is no longer needed.
|
|
// This function is used to reduce the number of allocations and
|
|
// to improve the performance of the session store.
|
|
//
|
|
// The session should not be used after calling this function.
|
|
//
|
|
// Important: The Release function should only be used when accessing the session directly,
|
|
// for example, when you have called func (s *Session) Get(ctx) to get the session.
|
|
// It should not be used when using the session with a *Middleware handler in the request
|
|
// call stack, as the middleware will still need to access the session.
|
|
//
|
|
// Usage:
|
|
//
|
|
// sess := session.Get(ctx)
|
|
// defer sess.Release()
|
|
func (s *Session) Release() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
releaseSession(s)
|
|
}
|
|
|
|
func releaseSession(s *Session) {
|
|
s.mu.Lock()
|
|
s.id = ""
|
|
s.idleTimeout = 0
|
|
s.ctx = nil
|
|
s.config = nil
|
|
if s.data != nil {
|
|
s.data.Reset()
|
|
}
|
|
s.mu.Unlock()
|
|
sessionPool.Put(s)
|
|
}
|
|
|
|
// Fresh returns whether the session is new
|
|
//
|
|
// Returns:
|
|
// - bool: True if the session is fresh, otherwise false.
|
|
//
|
|
// Usage:
|
|
//
|
|
// isFresh := s.Fresh()
|
|
func (s *Session) Fresh() bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.fresh
|
|
}
|
|
|
|
// ID returns the session ID
|
|
//
|
|
// Returns:
|
|
// - string: The session ID.
|
|
//
|
|
// Usage:
|
|
//
|
|
// id := s.ID()
|
|
func (s *Session) ID() string {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.id
|
|
}
|
|
|
|
// Get returns the value associated with the given key.
|
|
//
|
|
// Parameters:
|
|
// - key: The key to retrieve.
|
|
//
|
|
// Returns:
|
|
// - any: The value associated with the key.
|
|
//
|
|
// Usage:
|
|
//
|
|
// value := s.Get("key")
|
|
func (s *Session) Get(key any) any {
|
|
if s.data == nil {
|
|
return nil
|
|
}
|
|
return s.data.Get(key)
|
|
}
|
|
|
|
// Set updates or creates a new key-value pair in the session.
|
|
//
|
|
// Parameters:
|
|
// - key: The key to set.
|
|
// - val: The value to set.
|
|
//
|
|
// Usage:
|
|
//
|
|
// s.Set("key", "value")
|
|
func (s *Session) Set(key, val any) {
|
|
if s.data == nil {
|
|
return
|
|
}
|
|
s.data.Set(key, val)
|
|
}
|
|
|
|
// Delete removes the key-value pair from the session.
|
|
//
|
|
// Parameters:
|
|
// - key: The key to delete.
|
|
//
|
|
// Usage:
|
|
//
|
|
// s.Delete("key")
|
|
func (s *Session) Delete(key any) {
|
|
if s.data == nil {
|
|
return
|
|
}
|
|
s.data.Delete(key)
|
|
}
|
|
|
|
// Destroy deletes the session from storage and expires the session cookie.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the destruction fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.Destroy()
|
|
func (s *Session) Destroy() error {
|
|
if s.data == nil {
|
|
return nil
|
|
}
|
|
|
|
// Reset local data
|
|
s.data.Reset()
|
|
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
// Use external Storage if exist
|
|
if err := s.config.Storage.Delete(s.id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Expire session
|
|
s.delSession()
|
|
return nil
|
|
}
|
|
|
|
// Regenerate generates a new session id and deletes the old one from storage.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the regeneration fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.Regenerate()
|
|
func (s *Session) Regenerate() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Delete old id from storage
|
|
if err := s.config.Storage.Delete(s.id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Generate a new session, and set session.fresh to true
|
|
s.refresh()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Reset generates a new session id, deletes the old one from storage, and resets the associated data.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the reset fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.Reset()
|
|
func (s *Session) Reset() error {
|
|
// Reset local data
|
|
if s.data != nil {
|
|
s.data.Reset()
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Reset expiration
|
|
s.idleTimeout = 0
|
|
|
|
// Delete old id from storage
|
|
if err := s.config.Storage.Delete(s.id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Expire session
|
|
s.delSession()
|
|
|
|
// Generate a new session, and set session.fresh to true
|
|
s.refresh()
|
|
|
|
return nil
|
|
}
|
|
|
|
// refresh generates a new session, and sets session.fresh to be true.
|
|
func (s *Session) refresh() {
|
|
s.id = s.config.KeyGenerator()
|
|
s.fresh = true
|
|
}
|
|
|
|
// Save saves the session data and updates the cookie
|
|
//
|
|
// Note: If the session is being used in the handler, calling Save will have
|
|
// no effect and the session will automatically be saved when the handler returns.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the save operation fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.Save()
|
|
func (s *Session) Save() error {
|
|
if s.ctx == nil {
|
|
return s.saveSession()
|
|
}
|
|
|
|
// If the session is being used in the handler, it should not be saved
|
|
if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok {
|
|
if m.Session == s {
|
|
// Session is in use, so we do nothing and return
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return s.saveSession()
|
|
}
|
|
|
|
// saveSession encodes session data to saves it to storage.
|
|
func (s *Session) saveSession() error {
|
|
if s.data == nil {
|
|
return nil
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Set idleTimeout if not already set
|
|
if s.idleTimeout <= 0 {
|
|
s.idleTimeout = s.config.IdleTimeout
|
|
}
|
|
|
|
// Update client cookie
|
|
s.setSession()
|
|
|
|
// Encode session data
|
|
s.data.RLock()
|
|
encodedBytes, err := s.encodeSessionData()
|
|
s.data.RUnlock()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encode data: %w", err)
|
|
}
|
|
|
|
// Pass copied bytes with session id to provider
|
|
return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout)
|
|
}
|
|
|
|
// Keys retrieves all keys in the current session.
|
|
//
|
|
// Returns:
|
|
// - []string: A slice of all keys in the session.
|
|
//
|
|
// Usage:
|
|
//
|
|
// keys := s.Keys()
|
|
func (s *Session) Keys() []any {
|
|
if s.data == nil {
|
|
return []any{}
|
|
}
|
|
return s.data.Keys()
|
|
}
|
|
|
|
// SetIdleTimeout used when saving the session on the next call to `Save()`.
|
|
//
|
|
// Parameters:
|
|
// - idleTimeout: The duration for the idle timeout.
|
|
//
|
|
// Usage:
|
|
//
|
|
// s.SetIdleTimeout(time.Hour)
|
|
func (s *Session) SetIdleTimeout(idleTimeout time.Duration) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.idleTimeout = idleTimeout
|
|
}
|
|
|
|
func (s *Session) setSession() {
|
|
if s.ctx == nil {
|
|
return
|
|
}
|
|
|
|
if s.config.source == SourceHeader {
|
|
s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id))
|
|
s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id))
|
|
} else {
|
|
fcookie := fasthttp.AcquireCookie()
|
|
fcookie.SetKey(s.config.sessionName)
|
|
fcookie.SetValue(s.id)
|
|
fcookie.SetPath(s.config.CookiePath)
|
|
fcookie.SetDomain(s.config.CookieDomain)
|
|
// Cookies are also session cookies if they do not specify the Expires or Max-Age attribute.
|
|
// refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
|
if !s.config.CookieSessionOnly {
|
|
fcookie.SetMaxAge(int(s.idleTimeout.Seconds()))
|
|
fcookie.SetExpire(time.Now().Add(s.idleTimeout))
|
|
}
|
|
fcookie.SetSecure(s.config.CookieSecure)
|
|
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
|
|
|
|
switch utils.ToLower(s.config.CookieSameSite) {
|
|
case "strict":
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
|
case "none":
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
|
default:
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
|
}
|
|
s.ctx.Response().Header.SetCookie(fcookie)
|
|
fasthttp.ReleaseCookie(fcookie)
|
|
}
|
|
}
|
|
|
|
func (s *Session) delSession() {
|
|
if s.ctx == nil {
|
|
return
|
|
}
|
|
|
|
if s.config.source == SourceHeader {
|
|
s.ctx.Request().Header.Del(s.config.sessionName)
|
|
s.ctx.Response().Header.Del(s.config.sessionName)
|
|
} else {
|
|
s.ctx.Request().Header.DelCookie(s.config.sessionName)
|
|
s.ctx.Response().Header.DelCookie(s.config.sessionName)
|
|
|
|
fcookie := fasthttp.AcquireCookie()
|
|
fcookie.SetKey(s.config.sessionName)
|
|
fcookie.SetPath(s.config.CookiePath)
|
|
fcookie.SetDomain(s.config.CookieDomain)
|
|
fcookie.SetMaxAge(-1)
|
|
fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
|
|
fcookie.SetSecure(s.config.CookieSecure)
|
|
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
|
|
|
|
switch utils.ToLower(s.config.CookieSameSite) {
|
|
case "strict":
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
|
case "none":
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
|
default:
|
|
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
|
}
|
|
|
|
s.ctx.Response().Header.SetCookie(fcookie)
|
|
fasthttp.ReleaseCookie(fcookie)
|
|
}
|
|
}
|
|
|
|
// decodeSessionData decodes session data from raw bytes
|
|
//
|
|
// Parameters:
|
|
// - rawData: The raw byte data to decode.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the decoding fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.decodeSessionData(rawData)
|
|
func (s *Session) decodeSessionData(rawData []byte) error {
|
|
byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
|
defer byteBufferPool.Put(byteBuffer)
|
|
defer byteBuffer.Reset()
|
|
_, _ = byteBuffer.Write(rawData)
|
|
decCache := gob.NewDecoder(byteBuffer)
|
|
if err := decCache.Decode(&s.data.Data); err != nil {
|
|
return fmt.Errorf("failed to decode session data: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// encodeSessionData encodes session data to raw bytes
|
|
//
|
|
// Parameters:
|
|
// - rawData: The raw byte data to encode.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the encoding fails.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := s.encodeSessionData(rawData)
|
|
func (s *Session) encodeSessionData() ([]byte, error) {
|
|
byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
|
defer byteBufferPool.Put(byteBuffer)
|
|
defer byteBuffer.Reset()
|
|
encCache := gob.NewEncoder(byteBuffer)
|
|
if err := encCache.Encode(&s.data.Data); err != nil {
|
|
return nil, fmt.Errorf("failed to encode session data: %w", err)
|
|
}
|
|
// Copy the bytes
|
|
// Copy the data in buffer
|
|
encodedBytes := make([]byte, byteBuffer.Len())
|
|
copy(encodedBytes, byteBuffer.Bytes())
|
|
|
|
return encodedBytes, nil
|
|
}
|
|
|
|
// absExpiration returns the session absolute expiration time or a zero time if not set.
|
|
//
|
|
// Returns:
|
|
// - time.Time: The session absolute expiration time. Zero time if not set.
|
|
//
|
|
// Usage:
|
|
//
|
|
// expiration := s.absExpiration()
|
|
func (s *Session) absExpiration() time.Time {
|
|
absExpiration, ok := s.Get(absExpirationKey).(time.Time)
|
|
if ok {
|
|
return absExpiration
|
|
}
|
|
return time.Time{}
|
|
}
|
|
|
|
// isAbsExpired returns true if the session is expired.
|
|
//
|
|
// If the session has an absolute expiration time set, this function will return true if the
|
|
// current time is after the absolute expiration time.
|
|
//
|
|
// Returns:
|
|
// - bool: True if the session is expired, otherwise false.
|
|
func (s *Session) isAbsExpired() bool {
|
|
absExpiration := s.absExpiration()
|
|
return !absExpiration.IsZero() && time.Now().After(absExpiration)
|
|
}
|
|
|
|
// setAbsoluteExpiration sets the absolute session expiration time.
|
|
//
|
|
// Parameters:
|
|
// - expiration: The session expiration time.
|
|
//
|
|
// Usage:
|
|
//
|
|
// s.setExpiration(time.Now().Add(time.Hour))
|
|
func (s *Session) setAbsExpiration(absExpiration time.Time) {
|
|
s.Set(absExpirationKey, absExpiration)
|
|
}
|