package session import ( "encoding/gob" "errors" "fmt" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) // ErrEmptySessionID is an error that occurs when the session ID is empty. var ( ErrEmptySessionID = errors.New("session ID cannot be empty") ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int const ( // sessionIDContextKey is the key used to store the session ID in the context locals. sessionIDContextKey sessionIDKey = iota ) type Store struct { Config } // New creates a new session store with the provided configuration. // // Parameters: // - config: Variadic parameter to override default config. // // Returns: // - *Store: The session store. // // Usage: // // store := session.New() func NewStore(config ...Config) *Store { // Set default config cfg := configDefault(config...) if cfg.Storage == nil { cfg.Storage = memory.New() } store := &Store{ Config: cfg, } if cfg.AbsoluteTimeout > 0 { store.RegisterType(absExpirationKey) store.RegisterType(time.Time{}) } return store } // RegisterType registers a custom type for encoding/decoding into any storage provider. // // Parameters: // - i: The custom type to register. // // Usage: // // store.RegisterType(MyCustomType{}) func (*Store) RegisterType(i any) { gob.Register(i) } // Get will get/create a session. // // This function will return an ErrSessionAlreadyLoadedByMiddleware if // the session is already loaded by the middleware. // // Parameters: // - c: The Fiber context. // // Returns: // - *Session: The session object. // - error: An error if the session retrieval fails or if the session is already loaded by the middleware. // // Usage: // // sess, err := store.Get(c) // if err != nil { // // handle error // } func (s *Store) Get(c fiber.Ctx) (*Session, error) { // If session is already loaded in the context, // it should not be loaded again _, ok := c.Locals(middlewareContextKey).(*Middleware) if ok { return nil, ErrSessionAlreadyLoadedByMiddleware } return s.getSession(c) } // getSession retrieves a session based on the context. // // Parameters: // - c: The Fiber context. // // Returns: // - *Session: The session object. // - error: An error if the session retrieval fails. // // Usage: // // sess, err := store.getSession(c) // if err != nil { // // handle error // } func (s *Store) getSession(c fiber.Ctx) (*Session, error) { var rawData []byte var err error id, ok := c.Locals(sessionIDContextKey).(string) if !ok { id = s.getSessionID(c) } fresh := ok // Assume the session is fresh if the ID is found in locals // Attempt to fetch session data if an ID is provided if id != "" { rawData, err = s.Storage.Get(id) if err != nil { return nil, err } if rawData == nil { // Data not found, prepare to generate a new session id = "" } } // Generate a new ID if needed if id == "" { fresh = true // The session is fresh if a new ID is generated id = s.KeyGenerator() c.Locals(sessionIDContextKey, id) } // Create session object sess := acquireSession() sess.mu.Lock() sess.ctx = c sess.config = s sess.id = id sess.fresh = fresh // Decode session data if found if rawData != nil { sess.data.Lock() err := sess.decodeSessionData(rawData) sess.data.Unlock() if err != nil { sess.mu.Unlock() sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", err) } } sess.mu.Unlock() if fresh && s.AbsoluteTimeout > 0 { sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } else if sess.isAbsExpired() { if err := sess.Reset(); err != nil { return nil, fmt.Errorf("failed to reset session: %w", err) } sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } return sess, nil } // getSessionID returns the session ID from cookies, headers, or query string. // // Parameters: // - c: The Fiber context. // // Returns: // - string: The session ID. // // Usage: // // id := store.getSessionID(c) func (s *Store) getSessionID(c fiber.Ctx) string { id := c.Cookies(s.sessionName) if len(id) > 0 { return utils.CopyString(id) } if s.source == SourceHeader { id = string(c.Request().Header.Peek(s.sessionName)) if len(id) > 0 { return id } } if s.source == SourceURLQuery { id = fiber.Query[string](c, s.sessionName) if len(id) > 0 { return utils.CopyString(id) } } return "" } // Reset deletes all sessions from the storage. // // Returns: // - error: An error if the reset operation fails. // // Usage: // // err := store.Reset() // if err != nil { // // handle error // } func (s *Store) Reset() error { return s.Storage.Reset() } // Delete deletes a session by its ID. // // Parameters: // - id: The unique identifier of the session. // // Returns: // - error: An error if the deletion fails or if the session ID is empty. // // Usage: // // err := store.Delete(id) // if err != nil { // // handle error // } func (s *Store) Delete(id string) error { if id == "" { return ErrEmptySessionID } return s.Storage.Delete(id) } // GetByID retrieves a session by its ID from the storage. // If the session is not found, it returns nil and an error. // // Unlike session middleware methods, this function does not automatically: // // - Load the session into the request context. // // - Save the session data to the storage or update the client cookie. // // Important Notes: // // - The session object returned by GetByID does not have a context associated with it. // // - When using this method alongside session middleware, there is a potential for collisions, // so be mindful of interactions between manually retrieved sessions and middleware-managed sessions. // // - If you modify a session returned by GetByID, you must call session.Save() to persist the changes. // // - When you are done with the session, you should call session.Release() to release the session back to the pool. // // Parameters: // - id: The unique identifier of the session. // // Returns: // - *Session: The session object if found, otherwise nil. // - error: An error if the session retrieval fails or if the session ID is empty. // // Usage: // // sess, err := store.GetByID(id) // if err != nil { // // handle error // } func (s *Store) GetByID(id string) (*Session, error) { if id == "" { return nil, ErrEmptySessionID } rawData, err := s.Storage.Get(id) if err != nil { return nil, err } if rawData == nil { return nil, ErrSessionIDNotFoundInStore } sess := acquireSession() sess.mu.Lock() sess.config = s sess.id = id sess.fresh = false sess.data.Lock() decodeErr := sess.decodeSessionData(rawData) sess.data.Unlock() sess.mu.Unlock() if decodeErr != nil { sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", decodeErr) } if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { if err := sess.Destroy(); err != nil { sess.Release() log.Errorf("failed to destroy session: %v", err) } return nil, ErrSessionIDNotFoundInStore } } return sess, nil }