feat!(middleware/session): re-write session middleware with handler (#3016)

* feat!(middleware/session): re-write session middleware with handler

* test(middleware/session): refactor to IdleTimeout

* fix: lint errors

* test: Save session after setting or deleting raw data in CSRF middleware

* Update middleware/session/middleware.go

Co-authored-by: Renan Bastos <renanbastos.tec@gmail.com>

* fix: mutex and globals order

* feat: Re-Add read lock to session Get method

* feat: Migrate New() to return middleware

* chore: Refactor session middleware to improve session handling

* chore: Private get on store

* chore: Update session middleware to use saveSession instead of save

* chore: Update session middleware to use getSession instead of get

* chore: Remove unused error handler in session middleware config

* chore: Update session middleware to use NewWithStore in CSRF tests

* test: add test

* fix: destroyed session and GHSA-98j2-3j3p-fw2v

* chore: Refactor session_test.go to use newStore() instead of New()

* feat: Improve session middleware test coverage and error handling

This commit improves the session middleware test coverage by adding assertions for the presence of the Set-Cookie header and the token value. It also enhances error handling by checking for the expected number of parts in the Set-Cookie header.

* chore: fix lint issues

* chore: Fix session middleware locking issue and improve error handling

* test: improve middleware test coverage and error handling

* test: Add idle timeout test case to session middleware test

* feat: add GetSession(id string) (*Session, error)

* chore: lint

* docs: Update session middleware docs

* docs: Security Note to examples

* docs: Add recommendation for CSRF protection in session middleware

* chore: markdown lint

* docs: Update session middleware docs

* docs: makrdown lint

* test(middleware/session): Add unit tests for session config.go

* test(middleware/session): Add unit tests for store.go

* test(middleware/session): Add data.go unit tests

* refactor(middleware/session): session tests and add session release test

- Refactor session tests to improve readability and maintainability.
- Add a new test case to ensure proper session release functionality.
- Update session.md

* refactor: session data locking in middleware/session/data.go

* refactor(middleware/session): Add unit test for session middleware store

* test: fix session_test.go and store_test.go unit tests

* refactor(docs): Update session.md with v3 changes to Expiration

* refactor(middleware/session): Improve data pool handling and locking

* chore(middleware/session): TODO for Expiration field in session config

* refactor(middleware/session): Improve session data pool handling and locking

* refactor(middleware/session): Improve session data pool handling and locking

* test(middleware/csrf): add session middleware coverage

* chroe(middleware/session): TODO for unregistered session middleware

* refactor(middleware/session): Update session middleware for v3 changes

* refactor(middleware/session): Update session middleware for v3 changes

* refactor(middleware/session): Update session middleware idle timeout

- Update the default idle timeout for session middleware from 24 hours to 30 minutes.
- Add a note in the session middleware documentation about the importance of the middleware order.

* docws(middleware/session): Add note about IdleTimeout requiring save using legacy approach

* refactor(middleware/session): Update session middleware idle timeout

Update the idle timeout for the session middleware to 30 minutes. This ensures that the session expires after a period of inactivity. The previous value was 24 hours, which is too long for most use cases. This change improves the security and efficiency of the session management.

* docs(middleware/session): Update session middleware idle timeout and configuration

* test(middleware/session): Fix tests for updated panics

* refactor(middleware/session): Update session middleware initialization and saving

* refactor(middleware/session): Remove unnecessary comment about negative IdleTimeout value

* refactor(middleware/session): Update session middleware make NewStore public

* refactor(middleware/session): Update session middleware Set, Get, and Delete methods

Refactor the Set, Get, and Delete methods in the session middleware to use more descriptive parameter names. Instead of using "middlewareContextKey", the methods now use "key" to represent the key of the session value. This improves the readability and clarity of the code.

* feat(middleware/session): AbsoluteTimeout and key any

* fix(middleware/session): locking issues and lint errors

* chore(middleware/session): Regenerate code in data_msgp.go

* refactor(middleware/session): rename GetSessionByID to GetByID

This commit also includes changes to the session_test.go and store_test.go files to add test cases for the new GetByID method.

* docs(middleware/session): AbsoluteTimeout

* refactor(middleware/csrf): Rename Expiration to IdleTimeout

* docs(whats-new): CSRF Rename Expiration to IdleTimeout and remove SessionKey field

* refactor(middleware/session): Rename expirationKeyType to absExpirationKeyType and update related functions

* refactor(middleware/session): rename Test_Session_Save_Absolute to Test_Session_Save_AbsoluteTimeout

* chore(middleware/session): update as per PR comments

* docs(middlware/session): fix indent lint

* fix(middleware/session): Address EfeCtn Comments

* refactor(middleware/session): Move bytesBuffer to it's own pool

* test(middleware/session): add decodeSessionData error coverage

* refactor(middleware/session): Update absolute timeout handling

- Update absolute timeout handling in getSession function
- Set absolute expiration time in getSession function
- Delete expired session in GetByID function

* refactor(session/middleware): fix *Session nil ctx when using Store.GetByID

* refactor(middleware/session): Remove unnecessary line in session_test.go

* fix(middleware/session): *Session lifecycle issues

* docs(middleware/session): Update GetByID method documentation

* docs(middleware/session): Update GetByID method documentation

* docs(middleware/session): markdown lint

* refactor(middleware/session): Simplify error handling in DefaultErrorHandler

* fix( middleware/session/config.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* add ctx releases for the test cases

---------

Co-authored-by: Renan Bastos <renanbastos.tec@gmail.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Co-authored-by: René <rene@gofiber.io>
pull/3181/head
Jason McNeil 2024-10-25 03:36:30 -03:00 committed by GitHub
parent 298975a982
commit e3232c1505
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2818 additions and 446 deletions

View File

@ -34,7 +34,7 @@ app.Use(csrf.New(csrf.Config{
KeyLookup: "header:X-Csrf-Token",
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
IdleTimeout: 30 * time.Minute,
KeyGenerator: utils.UUIDv4,
Extractor: func(c fiber.Ctx) (string, error) { ... },
}))
@ -106,15 +106,14 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error
| CookieSecure | `bool` | Indicates if the CSRF cookie is secure. | false |
| CookieHTTPOnly | `bool` | Indicates if the CSRF cookie is HTTP-only. | false |
| CookieSameSite | `string` | Value of SameSite cookie. | "Lax" |
| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. Ignores Expiration if set to true. | false |
| Expiration | `time.Duration` | Expiration is the duration before the CSRF token will expire. | 1 * time.Hour |
| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. (cookie expires on close). | false |
| IdleTimeout | `time.Duration` | IdleTimeout is the duration of inactivity before the CSRF token will expire. | 30 * time.Minute |
| KeyGenerator | `func() string` | KeyGenerator creates a new CSRF token. | utils.UUID |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler is executed when an error is returned from fiber.Handler. | DefaultErrorHandler |
| Extractor | `func(fiber.Ctx) (string, error)` | Extractor returns the CSRF token. If set, this will be used in place of an Extractor based on KeyLookup. | Extractor based on KeyLookup |
| SingleUseToken | `bool` | SingleUseToken indicates if the CSRF token be destroyed and a new one generated on each use. (See TokenLifecycle) | false |
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
### Default Config
@ -124,11 +123,10 @@ var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
IdleTimeout: 30 * time.Minute,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: FromHeader(HeaderName),
SessionKey: "csrfToken",
}
```
@ -144,12 +142,11 @@ var ConfigDefault = Config{
CookieSecure: true,
CookieSessionOnly: true,
CookieHTTPOnly: true,
Expiration: 1 * time.Hour,
IdleTimeout: 30 * time.Minute,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: FromHeader(HeaderName),
Session: session.Store,
SessionKey: "csrfToken",
}
```
@ -304,7 +301,7 @@ The Referer header is automatically included in requests by all modern browsers,
## Token Lifecycle
Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time.
Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 30 minutes, and each subsequent request extends the expiration by the idle timeout. The token only expires if the user doesn't make a request for the duration of the idle timeout.
### Token Reuse

View File

@ -2,142 +2,481 @@
id: session
---
# Session
# Session Middleware for [Fiber](https://github.com/gofiber/fiber)
Session middleware for [Fiber](https://github.com/gofiber/fiber).
The `session` middleware provides session management for Fiber applications, utilizing the [Storage](https://github.com/gofiber/storage) package for multi-database support via a unified interface. By default, session data is stored in memory, but custom storage options are easily configurable (see examples below).
As of v3, we recommend using the middleware handler for session management. However, for backward compatibility, v2's session methods are still available, allowing you to continue using the session management techniques from earlier versions of Fiber. Both methods are demonstrated in the examples.
## Table of Contents
- [Migration Guide](#migration-guide)
- [v2 to v3](#v2-to-v3)
- [Types](#types)
- [Config](#config)
- [Middleware](#middleware)
- [Session](#session)
- [Store](#store)
- [Signatures](#signatures)
- [Session Package Functions](#session-package-functions)
- [Config Methods](#config-methods)
- [Middleware Methods](#middleware-methods)
- [Session Methods](#session-methods)
- [Store Methods](#store-methods)
- [Examples](#examples)
- [Middleware Handler (Recommended)](#middleware-handler-recommended)
- [Custom Storage Example](#custom-storage-example)
- [Session Without Middleware Handler](#session-without-middleware-handler)
- [Custom Types in Session Data](#custom-types-in-session-data)
- [Config](#config)
- [Default Config](#default-config)
## Migration Guide
### v2 to v3
- **Function Signature Change**: In v3, the `New` function now returns a middleware handler instead of a `*Store`. To access the store, use the `Store` method on `*Middleware` (obtained from `session.FromContext(c)` in a handler) or use `NewStore` or `NewWithStore`.
- **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly.
- **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with `IdleTimeout` and `AbsoluteTimeout` fields, which explicitly defines the session's idle and absolute timeout periods.
- **Idle Timeout**: The new `IdleTimeout`, handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically.
- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity.
For more details about Fiber v3, see [Whats New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md).
### Migrating v2 to v3 Example (Legacy Approach)
To convert a v2 example to use the v3 legacy approach, follow these steps:
1. **Initialize with Store**: Use `session.NewStore()` to obtain a store.
2. **Retrieve Session**: Access the session store using the `store.Get(c)` method.
3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle.
:::note
This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory, see the examples below for other databases.
When using the legacy approach, the IdleTimeout will be updated when the session is saved.
:::
#### Example Conversion
**v2 Example:**
```go
store := session.New()
app.Get("/", func(c *fiber.Ctx) error {
sess, err := store.Get(c)
if err != nil {
return err
}
key, ok := sess.Get("key").(string)
if !ok {
return c.SendStatus(fiber.StatusInternalServerError)
}
sess.Set("key", "value")
err = sess.Save()
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return nil
})
```
**v3 Legacy Approach:**
```go
store := session.NewStore()
app.Get("/", func(c fiber.Ctx) error {
sess, err := store.Get(c)
if err != nil {
return err
}
defer sess.Release() // Important: Release the session
key, ok := sess.Get("key").(string)
if !ok {
return c.SendStatus(fiber.StatusInternalServerError)
}
sess.Set("key", "value")
err = sess.Save()
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return nil
})
```
### v3 Example (Recommended Middleware Handler)
Do not call `sess.Release()` when using the middleware handler. `sess.Save()` is also not required, as the middleware automatically saves the session data.
For the recommended approach, use the middleware handler. See the [Middleware Handler (Recommended)](#middleware-handler-recommended) section for details.
## Types
### Config
Defines the configuration options for the session middleware.
```go
type Config struct {
Storage fiber.Storage
Next func(fiber.Ctx) bool
Store *Store
ErrorHandler func(fiber.Ctx, error)
KeyGenerator func() string
KeyLookup string
CookieDomain string
CookiePath string
CookieSameSite string
IdleTimeout time.Duration
AbsoluteTimeout time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieSessionOnly bool
}
```
### Middleware
The `Middleware` struct encapsulates the session middleware configuration and storage, created via `New` or `NewWithStore`.
```go
type Middleware struct {
Session *Session
}
```
### Session
Represents a user session, accessible through `FromContext` or `Store.Get`.
```go
type Session struct {}
```
### Store
Handles session data management and is created using `NewStore`, `NewWithStore` or by accessing the `Store` method of a middleware instance.
```go
type Store struct {
Config
}
```
## Signatures
```go
func New(config ...Config) *Store
func (s *Store) RegisterType(i any)
func (s *Store) Get(c fiber.Ctx) (*Session, error)
func (s *Store) Delete(id string) error
func (s *Store) Reset() error
### Session Package Functions
func (s *Session) Get(key string) any
func (s *Session) Set(key string, val any)
func (s *Session) Delete(key string)
func (s *Session) Destroy() error
func (s *Session) Reset() error
func (s *Session) Regenerate() error
func (s *Session) Save() error
func (s *Session) Fresh() bool
func (s *Session) ID() string
func (s *Session) Keys() []string
func (s *Session) SetExpiry(exp time.Duration)
```go
func New(config ...Config) *Middleware
func NewWithStore(config ...Config) (fiber.Handler, *Store)
func FromContext(c fiber.Ctx) *Middleware
```
:::caution
Storing `any` values are limited to built-ins Go types.
### Config Methods
```go
func DefaultErrorHandler(fiber.Ctx, err error)
```
### Middleware Methods
```go
func (m *Middleware) Set(key string, value any)
func (m *Middleware) Get(key string) any
func (m *Middleware) Delete(key string)
func (m *Middleware) Destroy() error
func (m *Middleware) Reset() error
func (m *Middleware) Store() *Store
```
### Session Methods
```go
func (s *Session) Fresh() bool
func (s *Session) ID() string
func (s *Session) Get(key string) any
func (s *Session) Set(key string, val any)
func (s *Session) Destroy() error
func (s *Session) Regenerate() error
func (s *Session) Release()
func (s *Session) Reset() error
func (s *Session) Save() error
func (s *Session) Keys() []string
func (s *Session) SetIdleTimeout(idleTimeout time.Duration)
```
### Store Methods
```go
func (*Store) RegisterType(i any)
func (s *Store) Get(c fiber.Ctx) (*Session, error)
func (s *Store) GetByID(id string) (*Session, error)
func (s *Store) Reset() error
func (s *Store) Delete(id string) error
```
:::note
#### `GetByID` Method
The `GetByID` method retrieves a session from storage using its session ID. Unlike `Get`, which ties the session to a `fiber.Ctx` (request-response cycle), `GetByID` operates independently of any HTTP context. This makes it ideal for scenarios such as background processing, scheduled tasks, or non-HTTP-related session management.
##### Key Features
- **Context Independence**: Sessions retrieved via `GetByID` are not bound to `fiber.Ctx`. This means the session can be manipulated in contexts that aren't tied to an active HTTP request-response cycle.
- **Background Task Suitability**: Use this method when you need to manage sessions outside of the standard HTTP workflow, such as in scheduled jobs, background tasks, or any non-HTTP context where session data needs to be accessed or modified.
##### Usage Considerations
- **Manual Persistence**: Since there is no associated `fiber.Ctx`, changes made to the session (e.g., modifying data) will **not** automatically be saved to storage. You **must** call `session.Save()` explicitly to persist any updates to storage.
- **No Automatic Cookie Handling**: Any updates made to the session will **not** affect the client-side cookies. If the session changes need to be reflected in the client (e.g., in a future HTTP response), you will need to handle this manually by setting the cookies via other methods.
- **Resource Management**: After using a session retrieved by `GetByID`, you should call `session.Release()` to properly release the session back to the pool and free up resources.
##### Example Use Cases
- **Scheduled Jobs**: Retrieve and update session data periodically without triggering an HTTP request.
- **Background Processing**: Manage sessions for tasks running in the background, such as user inactivity checks or batch processing.
:::
## Examples
Import the middleware package that is part of the Fiber web framework
:::note
**Security Notice**: For robust security, especially during sensitive operations like account changes or transactions, consider using CSRF protection. Fiber provides a [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) that can be used with sessions to prevent CSRF attacks.
:::
:::note
**Middleware Order**: The order of middleware matters. The session middleware should come before any handler or middleware that uses the session (for example, the CSRF middleware).
:::
### Middleware Handler (Recommended)
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/csrf"
"github.com/gofiber/fiber/v3/middleware/session"
)
func main() {
app := fiber.New()
sessionMiddleware, sessionStore := session.NewWithStore()
app.Use(sessionMiddleware)
app.Use(csrf.New(csrf.Config{
Store: sessionStore,
}))
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
name, ok := sess.Get("name").(string)
if !ok {
return c.SendString("Welcome anonymous user!")
}
return c.SendString("Welcome " + name)
})
app.Listen(":3000")
}
```
### Custom Storage Example
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/storage/sqlite3"
"github.com/gofiber/fiber/v3/middleware/csrf"
"github.com/gofiber/fiber/v3/middleware/session"
)
func main() {
app := fiber.New()
storage := sqlite3.New()
sessionMiddleware, sessionStore := session.NewWithStore(session.Config{
Storage: storage,
})
app.Use(sessionMiddleware)
app.Use(csrf.New(csrf.Config{
Store: sessionStore,
}))
app.Listen(":3000")
}
```
### Session Without Middleware Handler
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/csrf"
"github.com/gofiber/fiber/v3/middleware/session"
)
func main() {
app := fiber.New()
sessionStore := session.NewStore()
app.Use(csrf.New(csrf.Config{
Store: sessionStore,
}))
app.Get("/", func(c fiber.Ctx) error {
sess, err := sessionStore.Get(c)
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
defer sess.Release()
name, ok := sess.Get("name").(string)
if !ok {
return c.SendString("Welcome anonymous user!")
}
return c.SendString("Welcome " + name)
})
app.Post("/login", func(c fiber.Ctx) error {
sess, err := sessionStore.Get(c)
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
defer sess.Release()
if !sess.Fresh() {
if err := sess.Regenerate(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
}
sess.Set("name", "John Doe")
err = sess.Save()
if err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendString("Logged in!")
})
app.Listen(":3000")
}
```
### Custom Types in Session Data
Session data can only be of the following types by default:
- `string`
- `int`
- `int8`
- `int16`
- `int32`
- `int64`
- `uint`
- `uint8`
- `uint16`
- `uint32`
- `uint64`
- `bool`
- `float32`
- `float64`
- `[]byte`
- `complex64`
- `complex128`
- `interface{}`
To support other types in session data, you can register custom types. Here is an example of how to register a custom type:
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
```
After you initiate your Fiber app, you can use the following possibilities:
type User struct {
Name string
Age int
}
```go
// Initialize default config
// This stores all of your app's sessions
store := session.New()
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
// Get session from storage
sess, err := store.Get(c)
if err != nil {
panic(err)
}
sessionMiddleware, sessionStore := session.NewWithStore()
sessionStore.RegisterType(User{})
// Get value
name := sess.Get("name")
app.Use(sessionMiddleware)
// Set key/value
sess.Set("name", "john")
// Get all Keys
keys := sess.Keys()
// Delete key
sess.Delete("name")
// Destroy session
if err := sess.Destroy(); err != nil {
panic(err)
}
// Sets a specific expiration for this session
sess.SetExpiry(time.Second * 2)
// Save session
if err := sess.Save(); err != nil {
panic(err)
}
return c.SendString(fmt.Sprintf("Welcome %v", name))
})
app.Listen(":3000")
}
```
## Config
| Property | Type | Description | Default |
|:------------------------|:----------------|:------------------------------------------------------------------------------------------------------------|:----------------------|
| Expiration | `time.Duration` | Allowed session duration. | `24 * time.Hour` |
| Storage | `fiber.Storage` | Storage interface to store the session data. | `memory.New()` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract session id from the request. | `"cookie:session_id"` |
| CookieDomain | `string` | Domain of the cookie. | `""` |
| CookiePath | `string` | Path of the cookie. | `""` |
| CookieSecure | `bool` | Indicates if cookie is secure. | `false` |
| CookieHTTPOnly | `bool` | Indicates if cookie is HTTP only. | `false` |
| CookieSameSite | `string` | Value of SameSite cookie. | `"Lax"` |
| CookieSessionOnly | `bool` | Decides whether cookie should last for only the browser session. Ignores Expiration if set to true. | `false` |
| KeyGenerator | `func() string` | KeyGenerator generates the session key. | `utils.UUIDv4` |
| CookieName (Deprecated) | `string` | Deprecated: Please use KeyLookup. The session name. | `""` |
| Property | Type | Description | Default |
|-----------------------|--------------------------------|--------------------------------------------------------------------------------------------|---------------------------|
| **Storage** | `fiber.Storage` | Defines where session data is stored. | `nil` (in-memory storage) |
| **Next** | `func(c fiber.Ctx) bool` | Function to skip this middleware under certain conditions. | `nil` |
| **ErrorHandler** | `func(c fiber.Ctx, err error)` | Custom error handler for session middleware errors. | `nil` |
| **KeyGenerator** | `func() string` | Function to generate session IDs. | `UUID()` |
| **KeyLookup** | `string` | Key used to store session ID in cookie or header. | `"cookie:session_id"` |
| **CookieDomain** | `string` | The domain scope of the session cookie. | `""` |
| **CookiePath** | `string` | The path scope of the session cookie. | `"/"` |
| **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` |
| **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `30 * time.Minute` |
| **AbsoluteTimeout** | `time.Duration` | Maximum duration before session expires. | `0` (no expiration) |
| **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` |
| **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` |
| **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` |
## Default Config
```go
var ConfigDefault = Config{
Expiration: 24 * time.Hour,
KeyLookup: "cookie:session_id",
KeyGenerator: utils.UUIDv4,
source: "cookie",
sessionName: "session_id",
session.Config{
Storage: memory.New(),
Next: nil,
Store: nil,
ErrorHandler: nil,
KeyGenerator: utils.UUIDv4,
KeyLookup: "cookie:session_id",
CookieDomain: "",
CookiePath: "",
CookieSameSite: "Lax",
IdleTimeout: 30 * time.Minute,
AbsoluteTimeout: 0,
CookieSecure: false,
CookieHTTPOnly: false,
CookieSessionOnly: false,
}
```
## Constants
```go
const (
SourceCookie Source = "cookie"
SourceHeader Source = "header"
SourceURLQuery Source = "query"
)
```
### Custom Storage/Database
You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
```go
storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3
store := session.New(session.Config{
Storage: storage,
})
```
To use the store, see the [Examples](#examples).

View File

@ -30,6 +30,7 @@ Here's a quick overview of the changes in Fiber `v3`:
- [🧰 Generic functions](#-generic-functions)
- [🧬 Middlewares](#-middlewares)
- [CORS](#cors)
- [CSRF](#csrf)
- [Session](#session)
- [Filesystem](#filesystem)
- [Monitor](#monitor)
@ -316,9 +317,19 @@ Added support for specifying Key length when using `encryptcookie.GenerateKey(le
### Session
:::caution
DRAFT section
:::
The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management.
#### Key Updates
- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration.
- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`.
- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically.
- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity.
For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide).
### Filesystem
@ -521,6 +532,24 @@ app.Use(cors.New(cors.Config{
}))
```
#### CSRF
- **Field Renaming**: The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes. Update your code as follows:
```go
// Before
app.Use(csrf.New(csrf.Config{
Expiration: 10 * time.Minute,
}))
// After
app.Use(csrf.New(csrf.Config{
IdleTimeout: 10 * time.Minute,
}))
```
- **Session Key Removal**: The `SessionKey` field has been removed from the CSRF middleware configuration. The session key is now an unexported constant within the middleware to avoid potential key collisions in the session store.
#### Filesystem
You need to move filesystem middleware to static middleware due to it has been removed from the core.

View File

@ -78,11 +78,6 @@ type Config struct {
// Optional. Default value "Lax".
CookieSameSite string
// SessionKey is the key used to store the token in the session
//
// Default: "csrfToken"
SessionKey string
// TrustedOrigins is a list of trusted origins for unsafe requests.
// For requests that use the Origin header, the origin must match the
// Host header or one of the TrustedOrigins.
@ -96,10 +91,10 @@ type Config struct {
// Optional. Default: []
TrustedOrigins []string
// Expiration is the duration before csrf token will expire
// IdleTimeout is the duration of time the CSRF token is valid.
//
// Optional. Default: 1 * time.Hour
Expiration time.Duration
// Optional. Default: 30 * time.Minute
IdleTimeout time.Duration
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
@ -127,11 +122,10 @@ var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
IdleTimeout: 30 * time.Minute,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: FromHeader(HeaderName),
SessionKey: "csrfToken",
}
// default ErrorHandler that process return error from fiber.Handler
@ -153,8 +147,8 @@ func configDefault(config ...Config) Config {
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
}
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
if cfg.IdleTimeout <= 0 {
cfg.IdleTimeout = ConfigDefault.IdleTimeout
}
if cfg.CookieName == "" {
cfg.CookieName = ConfigDefault.CookieName
@ -168,9 +162,6 @@ func configDefault(config ...Config) Config {
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
if cfg.SessionKey == "" {
cfg.SessionKey = ConfigDefault.SessionKey
}
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")

View File

@ -49,10 +49,7 @@ func New(config ...Config) fiber.Handler {
var sessionManager *sessionManager
var storageManager *storageManager
if cfg.Session != nil {
// Register the Token struct in the session store
cfg.Session.RegisterType(Token{})
sessionManager = newSessionManager(cfg.Session, cfg.SessionKey)
sessionManager = newSessionManager(cfg.Session)
} else {
storageManager = newStorageManager(cfg.Storage)
}
@ -220,9 +217,9 @@ func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *se
// createOrExtendTokenInStorage creates or extends the token in the storage
func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
if cfg.Session != nil {
sessionManager.setRaw(c, token, dummyValue, cfg.Expiration)
sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout)
} else {
storageManager.setRaw(token, dummyValue, cfg.Expiration)
storageManager.setRaw(token, dummyValue, cfg.IdleTimeout)
}
}
@ -237,7 +234,7 @@ func deleteTokenFromStorage(c fiber.Ctx, token string, cfg Config, sessionManage
// Update CSRF cookie
// if expireCookie is true, the cookie will expire immediately
func updateCSRFCookie(c fiber.Ctx, cfg Config, token string) {
setCSRFCookie(c, cfg, token, cfg.Expiration)
setCSRFCookie(c, cfg, token, cfg.IdleTimeout)
}
func expireCSRFCookie(c fiber.Ctx, cfg Config) {

View File

@ -70,7 +70,7 @@ func Test_CSRF_WithSession(t *testing.T) {
t.Parallel()
// session store
store := session.New(session.Config{
store := session.NewStore(session.Config{
KeyLookup: "cookie:_session",
})
@ -156,13 +156,68 @@ func Test_CSRF_WithSession(t *testing.T) {
}
}
// go test -run Test_CSRF_WithSession_Middleware
func Test_CSRF_WithSession_Middleware(t *testing.T) {
t.Parallel()
app := fiber.New()
// session mw
smh, sstore := session.NewWithStore()
// csrf mw
cmh := New(Config{
Session: sstore,
})
app.Use(smh)
app.Use(cmh)
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
sess.Set("hello", "world")
return c.SendStatus(fiber.StatusOK)
})
app.Post("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
if sess.Get("hello") != "world" {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token and session_id
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
csrfTokenParts := strings.Split(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), ";")
require.Greater(t, len(csrfTokenParts), 2)
csrfToken := strings.Split(csrfTokenParts[0], "=")[1]
require.NotEmpty(t, csrfToken)
sessionID := strings.Split(csrfTokenParts[1], "=")[1]
require.NotEmpty(t, sessionID)
// Use the CSRF token and session_id
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, csrfToken)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken)
ctx.Request.Header.SetCookie("session_id", sessionID)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
}
// go test -run Test_CSRF_ExpiredToken
func Test_CSRF_ExpiredToken(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Expiration: 1 * time.Second,
IdleTimeout: 1 * time.Second,
}))
app.Post("/", func(c fiber.Ctx) error {
@ -205,7 +260,7 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
t.Parallel()
// session store
store := session.New(session.Config{
store := session.NewStore(session.Config{
KeyLookup: "cookie:_session",
})
@ -229,8 +284,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
// middleware config
config := Config{
Session: store,
Expiration: 1 * time.Second,
Session: store,
IdleTimeout: 1 * time.Second,
}
// middleware
@ -1076,7 +1131,7 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
t.Parallel()
// session store
store := session.New(session.Config{
store := session.NewStore(session.Config{
KeyLookup: "cookie:_session",
})

View File

@ -10,28 +10,46 @@ import (
type sessionManager struct {
session *session.Store
key string
}
func newSessionManager(s *session.Store, k string) *sessionManager {
type sessionKeyType int
const (
sessionKey sessionKeyType = 0
)
func newSessionManager(s *session.Store) *sessionManager {
// Create new storage handler
sessionManager := &sessionManager{
key: k,
}
sessionManager := new(sessionManager)
if s != nil {
// Use provided storage if provided
sessionManager.session = s
// Register the sessionKeyType and Token type
s.RegisterType(sessionKeyType(0))
s.RegisterType(Token{})
}
return sessionManager
}
// get token from session
func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte {
sess, err := m.session.Get(c)
if err != nil {
return nil
sess := session.FromContext(c)
var token Token
var ok bool
if sess != nil {
token, ok = sess.Get(sessionKey).(Token)
} else {
// Try to get the session from the store
storeSess, err := m.session.Get(c)
if err != nil {
// Handle error
return nil
}
token, ok = storeSess.Get(sessionKey).(Token)
}
token, ok := sess.Get(m.key).(Token)
if ok {
if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
return nil
@ -44,25 +62,39 @@ func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte {
// set token in session
func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Duration) {
sess, err := m.session.Get(c)
if err != nil {
return
}
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
sess.Set(m.key, &Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)})
if err := sess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
sess := session.FromContext(c)
if sess != nil {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
sess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)})
} else {
// Try to get the session from the store
storeSess, err := m.session.Get(c)
if err != nil {
// Handle error
return
}
storeSess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)})
if err := storeSess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
}
}
}
// delete token from session
func (m *sessionManager) delRaw(c fiber.Ctx) {
sess, err := m.session.Get(c)
if err != nil {
return
}
sess.Delete(m.key)
if err := sess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
sess := session.FromContext(c)
if sess != nil {
sess.Delete(sessionKey)
} else {
// Try to get the session from the store
storeSess, err := m.session.Get(c)
if err != nil {
// Handle error
return
}
storeSess.Delete(sessionKey)
if err := storeSess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
}
}
}

View File

@ -5,60 +5,98 @@ import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// Config defines the config for middleware.
// Config defines the configuration for the session middleware.
type Config struct {
// Storage interface to store the session data
// Optional. Default value memory.New()
// Storage interface for storing session data.
//
// Optional. Default: memory.New()
Storage fiber.Storage
// Next defines a function to skip this middleware when it returns true.
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// Store defines the session store.
//
// Required.
Store *Store
// ErrorHandler defines a function to handle errors.
//
// Optional. Default: nil
ErrorHandler func(fiber.Ctx, error)
// KeyGenerator generates the session key.
// Optional. Default value utils.UUIDv4
//
// Optional. Default: utils.UUIDv4
KeyGenerator func() string
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract session id from the request.
// Possible values: "header:<name>", "query:<name>" or "cookie:<name>"
// Optional. Default value "cookie:session_id".
// KeyLookup is a string in the format "<source>:<name>" used to extract the session ID from the request.
//
// Possible values: "header:<name>", "query:<name>", "cookie:<name>"
//
// Optional. Default: "cookie:session_id"
KeyLookup string
// Domain of the cookie.
// Optional. Default value "".
// CookieDomain defines the domain of the session cookie.
//
// Optional. Default: ""
CookieDomain string
// Path of the cookie.
// Optional. Default value "".
// CookiePath defines the path of the session cookie.
//
// Optional. Default: ""
CookiePath string
// Value of SameSite cookie.
// Optional. Default value "Lax".
// CookieSameSite specifies the SameSite attribute of the cookie.
//
// Optional. Default: "Lax"
CookieSameSite string
// Source defines where to obtain the session id
// Source defines where to obtain the session ID.
source Source
// The session name
// sessionName is the name of the session.
sessionName string
// Allowed session duration
// Optional. Default value 24 * time.Hour
Expiration time.Duration
// Indicates if cookie is secure.
// Optional. Default value false.
// IdleTimeout defines the maximum duration of inactivity before the session expires.
//
// Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically.
//
// Optional. Default: 30 * time.Minute
IdleTimeout time.Duration
// AbsoluteTimeout defines the maximum duration of the session before it expires.
//
// If set to 0, the session will not have an absolute timeout, and will expire after the idle timeout.
//
// Optional. Default: 0
AbsoluteTimeout time.Duration
// CookieSecure specifies if the session cookie should be secure.
//
// Optional. Default: false
CookieSecure bool
// Indicates if cookie is HTTP only.
// Optional. Default value false.
// CookieHTTPOnly specifies if the session cookie should be HTTP-only.
//
// Optional. Default: false
CookieHTTPOnly bool
// Decides whether cookie should last for only the browser sesison.
// Ignores Expiration if set to true
// Optional. Default value false.
// CookieSessionOnly determines if the cookie should expire when the browser session ends.
//
// If true, the cookie will be deleted when the browser is closed.
// Note: This will not delete the session data from the store.
//
// Optional. Default: false
CookieSessionOnly bool
}
// Source represents the type of session ID source.
type Source string
const (
@ -67,28 +105,59 @@ const (
SourceURLQuery Source = "query"
)
// ConfigDefault is the default config
// ConfigDefault provides the default configuration.
var ConfigDefault = Config{
Expiration: 24 * time.Hour,
IdleTimeout: 30 * time.Minute,
KeyLookup: "cookie:session_id",
KeyGenerator: utils.UUIDv4,
source: "cookie",
source: SourceCookie,
sessionName: "session_id",
}
// Helper function to set default values
// DefaultErrorHandler logs the error and sends a 500 status code.
//
// Parameters:
// - c: The Fiber context.
// - err: The error to handle.
//
// Usage:
//
// DefaultErrorHandler(c, err)
func DefaultErrorHandler(c fiber.Ctx, err error) {
log.Errorf("session: %v", err)
if sendErr := c.SendStatus(fiber.StatusInternalServerError); sendErr != nil {
log.Errorf("session: %v", sendErr)
}
}
// configDefault sets default values for the Config struct.
//
// Parameters:
// - config: Variadic parameter to override the default config.
//
// Returns:
// - Config: The configuration with default values set.
//
// Usage:
//
// cfg := configDefault()
// cfg := configDefault(customConfig)
func configDefault(config ...Config) Config {
// Return default config if nothing provided
// Return default config if none provided.
if len(config) < 1 {
return ConfigDefault
}
// Override default config
// Override default config with provided config.
cfg := config[0]
// Set default values
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
// Set default values where necessary.
if cfg.IdleTimeout <= 0 {
cfg.IdleTimeout = ConfigDefault.IdleTimeout
}
// Ensure AbsoluteTimeout is greater than or equal to IdleTimeout.
if cfg.AbsoluteTimeout > 0 && cfg.AbsoluteTimeout < cfg.IdleTimeout {
panic("[session] AbsoluteTimeout must be greater than or equal to IdleTimeout")
}
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
@ -97,10 +166,11 @@ func configDefault(config ...Config) Config {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
// Parse KeyLookup into source and session name.
selectors := strings.Split(cfg.KeyLookup, ":")
const numSelectors = 2
if len(selectors) != numSelectors {
panic("[session] KeyLookup must in the form of <source>:<name>")
panic("[session] KeyLookup must be in the format '<source>:<name>'")
}
switch Source(selectors[0]) {
case SourceCookie:
@ -110,7 +180,7 @@ func configDefault(config ...Config) Config {
case SourceURLQuery:
cfg.source = SourceURLQuery
default:
panic("[session] source is not supported")
panic("[session] unsupported source in KeyLookup")
}
cfg.sessionName = selectors[1]

View File

@ -0,0 +1,59 @@
package session
import (
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func TestConfigDefault(t *testing.T) {
// Test default config
cfg := configDefault()
require.Equal(t, 30*time.Minute, cfg.IdleTimeout)
require.Equal(t, "cookie:session_id", cfg.KeyLookup)
require.NotNil(t, cfg.KeyGenerator)
require.Equal(t, SourceCookie, cfg.source)
require.Equal(t, "session_id", cfg.sessionName)
}
func TestConfigDefaultWithCustomConfig(t *testing.T) {
// Test custom config
customConfig := Config{
IdleTimeout: 48 * time.Hour,
KeyLookup: "header:custom_session_id",
KeyGenerator: func() string { return "custom_key" },
}
cfg := configDefault(customConfig)
require.Equal(t, 48*time.Hour, cfg.IdleTimeout)
require.Equal(t, "header:custom_session_id", cfg.KeyLookup)
require.NotNil(t, cfg.KeyGenerator)
require.Equal(t, SourceHeader, cfg.source)
require.Equal(t, "custom_session_id", cfg.sessionName)
}
func TestDefaultErrorHandler(t *testing.T) {
// Create a new Fiber app
app := fiber.New()
// Create a new context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// Test DefaultErrorHandler
DefaultErrorHandler(ctx, fiber.ErrInternalServerError)
require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode())
}
func TestInvalidKeyLookupFormat(t *testing.T) {
require.PanicsWithValue(t, "[session] KeyLookup must be in the format '<source>:<name>'", func() {
configDefault(Config{KeyLookup: "invalid_format"})
})
}
func TestUnsupportedSource(t *testing.T) {
require.PanicsWithValue(t, "[session] unsupported source in KeyLookup", func() {
configDefault(Config{KeyLookup: "unsupported:session_id"})
})
}

View File

@ -8,57 +8,120 @@ import (
//
//go:generate msgp -o=data_msgp.go -tests=true -unexported
type data struct {
Data map[string]any
Data map[any]any
sync.RWMutex `msg:"-"`
}
var dataPool = sync.Pool{
New: func() any {
d := new(data)
d.Data = make(map[string]any)
d.Data = make(map[any]any)
return d
},
}
// acquireData returns a new data object from the pool.
//
// Returns:
// - *data: The data object.
//
// Usage:
//
// d := acquireData()
func acquireData() *data {
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
obj := dataPool.Get()
if d, ok := obj.(*data); ok {
return d
}
// Handle unexpected type in the pool
panic("unexpected type in data pool")
}
// Reset clears the data map and resets the data object.
//
// Usage:
//
// d.Reset()
func (d *data) Reset() {
d.Lock()
d.Data = make(map[string]any)
d.Unlock()
defer d.Unlock()
d.Data = make(map[any]any)
}
func (d *data) Get(key string) any {
// Get retrieves a value from the data map by key.
//
// Parameters:
// - key: The key to retrieve.
//
// Returns:
// - any: The value associated with the key.
//
// Usage:
//
// value := d.Get("key")
func (d *data) Get(key any) any {
d.RLock()
v := d.Data[key]
d.RUnlock()
return v
defer d.RUnlock()
return d.Data[key]
}
func (d *data) Set(key string, value any) {
// Set updates or creates a new key-value pair in the data map.
//
// Parameters:
// - key: The key to set.
// - value: The value to set.
//
// Usage:
//
// d.Set("key", "value")
func (d *data) Set(key, value any) {
d.Lock()
defer d.Unlock()
d.Data[key] = value
d.Unlock()
}
func (d *data) Delete(key string) {
// Delete removes a key-value pair from the data map.
//
// Parameters:
// - key: The key to delete.
//
// Usage:
//
// d.Delete("key")
func (d *data) Delete(key any) {
d.Lock()
defer d.Unlock()
delete(d.Data, key)
d.Unlock()
}
func (d *data) Keys() []string {
d.Lock()
keys := make([]string, 0, len(d.Data))
// Keys retrieves all keys in the data map.
//
// Returns:
// - []any: A slice of all keys in the data map.
//
// Usage:
//
// keys := d.Keys()
func (d *data) Keys() []any {
d.RLock()
defer d.RUnlock()
keys := make([]any, 0, len(d.Data))
for k := range d.Data {
keys = append(keys, k)
}
d.Unlock()
return keys
}
// Len returns the number of key-value pairs in the data map.
//
// Returns:
// - int: The number of key-value pairs.
//
// Usage:
//
// length := d.Len()
func (d *data) Len() int {
d.RLock()
defer d.RUnlock()
return len(d.Data)
}

View File

@ -24,36 +24,6 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) {
return
}
switch msgp.UnsafeString(field) {
case "Data":
var zb0002 uint32
zb0002, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
if z.Data == nil {
z.Data = make(map[string]interface{}, zb0002)
} else if len(z.Data) > 0 {
for key := range z.Data {
delete(z.Data, key)
}
}
for zb0002 > 0 {
zb0002--
var za0001 string
var za0002 interface{}
za0001, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
za0002, err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "Data", za0001)
return
}
z.Data[za0001] = za0002
}
default:
err = dc.Skip()
if err != nil {
@ -66,48 +36,22 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z *data) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 1
// write "Data"
err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61)
func (z data) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 0
_ = z
err = en.Append(0x80)
if err != nil {
return
}
err = en.WriteMapHeader(uint32(len(z.Data)))
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
for za0001, za0002 := range z.Data {
err = en.WriteString(za0001)
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
err = en.WriteIntf(za0002)
if err != nil {
err = msgp.WrapError(err, "Data", za0001)
return
}
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *data) MarshalMsg(b []byte) (o []byte, err error) {
func (z data) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 1
// string "Data"
o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61)
o = msgp.AppendMapHeader(o, uint32(len(z.Data)))
for za0001, za0002 := range z.Data {
o = msgp.AppendString(o, za0001)
o, err = msgp.AppendIntf(o, za0002)
if err != nil {
err = msgp.WrapError(err, "Data", za0001)
return
}
}
// map header, size 0
_ = z
o = append(o, 0x80)
return
}
@ -129,36 +73,6 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) {
return
}
switch msgp.UnsafeString(field) {
case "Data":
var zb0002 uint32
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
if z.Data == nil {
z.Data = make(map[string]interface{}, zb0002)
} else if len(z.Data) > 0 {
for key := range z.Data {
delete(z.Data, key)
}
}
for zb0002 > 0 {
var za0001 string
var za0002 interface{}
zb0002--
za0001, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Data")
return
}
za0002, bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Data", za0001)
return
}
z.Data[za0001] = za0002
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
@ -172,13 +86,7 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *data) Msgsize() (s int) {
s = 1 + 5 + msgp.MapHeaderSize
if z.Data != nil {
for za0001, za0002 := range z.Data {
_ = za0002
s += msgp.StringPrefixSize + len(za0001) + msgp.GuessSize(za0002)
}
}
func (z data) Msgsize() (s int) {
s = 1
return
}

View File

@ -0,0 +1,204 @@
package session
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKeys(t *testing.T) {
t.Parallel()
// Test case: Empty data
t.Run("Empty data", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
keys := d.Keys()
require.Empty(t, keys, "Expected no keys in empty data")
})
// Test case: Single key
t.Run("Single key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
keys := d.Keys()
require.Len(t, keys, 1, "Expected one key")
require.Contains(t, keys, "key1", "Expected key1 to be present")
})
// Test case: Multiple keys
t.Run("Multiple keys", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
keys := d.Keys()
require.Len(t, keys, 3, "Expected three keys")
require.Contains(t, keys, "key1", "Expected key1 to be present")
require.Contains(t, keys, "key2", "Expected key2 to be present")
require.Contains(t, keys, "key3", "Expected key3 to be present")
})
// Test case: Concurrent access
t.Run("Concurrent access", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
done := make(chan bool)
go func() {
keys := d.Keys()
assert.Len(t, keys, 3, "Expected three keys")
done <- true
}()
go func() {
keys := d.Keys()
assert.Len(t, keys, 3, "Expected three keys")
done <- true
}()
<-done
<-done
})
}
func TestData_Len(t *testing.T) {
t.Parallel()
// Test case: Empty data
t.Run("Empty data", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
length := d.Len()
require.Equal(t, 0, length, "Expected length to be 0 for empty data")
})
// Test case: Single key
t.Run("Single key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
length := d.Len()
require.Equal(t, 1, length, "Expected length to be 1 when one key is set")
})
// Test case: Multiple keys
t.Run("Multiple keys", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
length := d.Len()
require.Equal(t, 3, length, "Expected length to be 3 when three keys are set")
})
// Test case: Concurrent access
t.Run("Concurrent access", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
done := make(chan bool, 2) // Buffered channel with size 2
go func() {
length := d.Len()
assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access")
done <- true
}()
go func() {
length := d.Len()
assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access")
done <- true
}()
<-done
<-done
})
}
func TestData_Get(t *testing.T) {
t.Parallel()
// Test case: Non-existent key
t.Run("Non-existent key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
value := d.Get("non-existent-key")
require.Nil(t, value, "Expected nil for non-existent key")
})
// Test case: Existing key
t.Run("Existing key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
value := d.Get("key1")
require.Equal(t, "value1", value, "Expected value1 for key1")
})
}
func TestData_Reset(t *testing.T) {
t.Parallel()
// Test case: Reset data
t.Run("Reset data", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Reset()
require.Empty(t, d.Data, "Expected data map to be empty after reset")
})
}
func TestData_Delete(t *testing.T) {
t.Parallel()
// Test case: Delete existing key
t.Run("Delete existing key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Delete("key1")
value := d.Get("key1")
require.Nil(t, value, "Expected nil for deleted key")
})
// Test case: Delete non-existent key
t.Run("Delete non-existent key", func(t *testing.T) {
t.Parallel()
d := acquireData()
defer dataPool.Put(d)
defer d.Reset()
d.Delete("non-existent-key")
// No assertion needed, just ensure no panic or error
})
}

View File

@ -0,0 +1,301 @@
// Package session provides session management middleware for Fiber.
// This middleware handles user sessions, including storing session data in the store.
package session
import (
"errors"
"sync"
"github.com/gofiber/fiber/v3"
)
// Middleware holds session data and configuration.
type Middleware struct {
Session *Session
ctx fiber.Ctx
config Config
mu sync.RWMutex
destroyed bool
}
// Context key for session middleware lookup.
type middlewareKey int
const (
// middlewareContextKey is the key used to store the *Middleware in the context locals.
middlewareContextKey middlewareKey = iota
)
var (
// ErrTypeAssertionFailed occurs when a type assertion fails.
ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware")
// Pool for reusing middleware instances.
middlewarePool = &sync.Pool{
New: func() any {
return &Middleware{}
},
}
)
// New initializes session middleware with optional configuration.
//
// Parameters:
// - config: Variadic parameter to override default config.
//
// Returns:
// - fiber.Handler: The Fiber handler for the session middleware.
//
// Usage:
//
// app.Use(session.New())
//
// Usage:
//
// app.Use(session.New())
func New(config ...Config) fiber.Handler {
if len(config) > 0 {
handler, _ := NewWithStore(config[0])
return handler
}
handler, _ := NewWithStore()
return handler
}
// NewWithStore creates session middleware with an optional custom store.
//
// Parameters:
// - config: Variadic parameter to override default config.
//
// Returns:
// - fiber.Handler: The Fiber handler for the session middleware.
// - *Store: The session store.
//
// Usage:
//
// handler, store := session.NewWithStore()
func NewWithStore(config ...Config) (fiber.Handler, *Store) {
cfg := configDefault(config...)
if cfg.Store == nil {
cfg.Store = NewStore(cfg)
}
handler := func(c fiber.Ctx) error {
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Acquire session middleware
m := acquireMiddleware()
m.initialize(c, cfg)
stackErr := c.Next()
m.mu.RLock()
destroyed := m.destroyed
m.mu.RUnlock()
if !destroyed {
m.saveSession()
}
releaseMiddleware(m)
return stackErr
}
return handler, cfg.Store
}
// initialize sets up middleware for the request.
func (m *Middleware) initialize(c fiber.Ctx, cfg Config) {
m.mu.Lock()
defer m.mu.Unlock()
session, err := cfg.Store.getSession(c)
if err != nil {
panic(err) // handle or log this error appropriately in production
}
m.config = cfg
m.Session = session
m.ctx = c
c.Locals(middlewareContextKey, m)
}
// saveSession handles session saving and error management after the response.
func (m *Middleware) saveSession() {
if err := m.Session.saveSession(); err != nil {
if m.config.ErrorHandler != nil {
m.config.ErrorHandler(m.ctx, err)
} else {
DefaultErrorHandler(m.ctx, err)
}
}
releaseSession(m.Session)
}
// acquireMiddleware retrieves a middleware instance from the pool.
func acquireMiddleware() *Middleware {
m, ok := middlewarePool.Get().(*Middleware)
if !ok {
panic(ErrTypeAssertionFailed.Error())
}
return m
}
// releaseMiddleware resets and returns middleware to the pool.
//
// Parameters:
// - m: The middleware object to release.
//
// Usage:
//
// releaseMiddleware(m)
func releaseMiddleware(m *Middleware) {
m.mu.Lock()
m.config = Config{}
m.Session = nil
m.ctx = nil
m.destroyed = false
m.mu.Unlock()
middlewarePool.Put(m)
}
// FromContext returns the Middleware from the Fiber context.
//
// Parameters:
// - c: The Fiber context.
//
// Returns:
// - *Middleware: The middleware object if found, otherwise nil.
//
// Usage:
//
// m := session.FromContext(c)
func FromContext(c fiber.Ctx) *Middleware {
m, ok := c.Locals(middlewareContextKey).(*Middleware)
if !ok {
return nil
}
return m
}
// Set sets a key-value pair in the session.
//
// Parameters:
// - key: The key to set.
// - value: The value to set.
//
// Usage:
//
// m.Set("key", "value")
func (m *Middleware) Set(key, value any) {
m.mu.Lock()
defer m.mu.Unlock()
m.Session.Set(key, value)
}
// Get retrieves a value from the session by key.
//
// Parameters:
// - key: The key to retrieve.
//
// Returns:
// - any: The value associated with the key.
//
// Usage:
//
// value := m.Get("key")
func (m *Middleware) Get(key any) any {
m.mu.RLock()
defer m.mu.RUnlock()
return m.Session.Get(key)
}
// Delete removes a key-value pair from the session.
//
// Parameters:
// - key: The key to delete.
//
// Usage:
//
// m.Delete("key")
func (m *Middleware) Delete(key any) {
m.mu.Lock()
defer m.mu.Unlock()
m.Session.Delete(key)
}
// Destroy destroys the session.
//
// Returns:
// - error: An error if the destruction fails.
//
// Usage:
//
// err := m.Destroy()
func (m *Middleware) Destroy() error {
m.mu.Lock()
defer m.mu.Unlock()
err := m.Session.Destroy()
m.destroyed = true
return err
}
// Fresh checks if the session is fresh.
//
// Returns:
// - bool: True if the session is fresh, otherwise false.
//
// Usage:
//
// isFresh := m.Fresh()
func (m *Middleware) Fresh() bool {
return m.Session.Fresh()
}
// ID returns the session ID.
//
// Returns:
// - string: The session ID.
//
// Usage:
//
// id := m.ID()
func (m *Middleware) ID() string {
return m.Session.ID()
}
// Reset resets the session.
//
// Returns:
// - error: An error if the reset fails.
//
// Usage:
//
// err := m.Reset()
func (m *Middleware) Reset() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.Session.Reset()
}
// Store returns the session store.
//
// Returns:
// - *Store: The session store.
//
// Usage:
//
// store := m.Store()
func (m *Middleware) Store() *Store {
return m.config.Store
}

View File

@ -0,0 +1,469 @@
package session
import (
"strings"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func Test_Session_Middleware(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/get", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
value, ok := sess.Get("key").(string)
if !ok {
return c.Status(fiber.StatusNotFound).SendString("key not found")
}
return c.SendString("value=" + value)
})
app.Post("/set", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// get a value from the body
value := c.FormValue("value")
sess.Set("key", value)
return c.SendStatus(fiber.StatusOK)
})
app.Post("/delete", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
sess.Delete("key")
return c.SendStatus(fiber.StatusOK)
})
app.Post("/reset", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
if err := sess.Reset(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
app.Post("/destroy", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
if err := sess.Destroy(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
app.Post("/fresh", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Reset the session to make it fresh
if err := sess.Reset(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
if sess.Fresh() {
return c.SendStatus(fiber.StatusOK)
}
return c.SendStatus(fiber.StatusInternalServerError)
})
// Test GET, SET, DELETE, RESET, DESTROY by sending requests to the respective routes
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
h := app.Handler()
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
token = tokenParts[1]
require.Equal(t, "key not found", string(ctx.Response.Body()))
// Test POST /set
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/set")
ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type
ctx.Request.SetBodyString("value=hello")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test GET /get to check if the value was set
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
require.Equal(t, "value=hello", string(ctx.Response.Body()))
// Test POST /delete to delete the value
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/delete")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test GET /get to check if the value was deleted
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
require.Equal(t, "key not found", string(ctx.Response.Body()))
// Test POST /reset to reset the session
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/reset")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// verify we have a new session token
newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
newToken = newTokenParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /destroy to destroy the session
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/destroy")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Verify the session cookie is set to expire
setCookieHeader := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.Contains(t, setCookieHeader, "expires=")
cookieParts := strings.Split(setCookieHeader, ";")
expired := false
for _, part := range cookieParts {
if strings.Contains(part, "expires=") {
part = strings.TrimSpace(part)
expiryDateStr := strings.TrimPrefix(part, "expires=")
// Correctly parse the date with "GMT" timezone
expiryDate, err := time.Parse(time.RFC1123, strings.TrimSpace(expiryDateStr))
require.NoError(t, err)
if expiryDate.Before(time.Now()) {
expired = true
break
}
}
}
require.True(t, expired, "Session cookie should be expired")
// Sleep so that the session expires
time.Sleep(1 * time.Second)
// Test GET /get to check if the session was destroyed
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
// check that we have a new session token
newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
parts := strings.Split(newToken, ";")
require.Greater(t, len(parts), 1)
valueParts := strings.Split(parts[0], "=")
require.Greater(t, len(valueParts), 1)
newToken = valueParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /fresh to check if the session is fresh
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/fresh")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// check that we have a new session token
newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
newToken = newTokenParts[1]
require.NotEqual(t, token, newToken)
}
func Test_Session_NewWithStore(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
sess := FromContext(c)
id := sess.ID()
return c.SendString("value=" + id)
})
app.Post("/", func(c fiber.Ctx) error {
sess := FromContext(c)
id := sess.ID()
c.Cookie(&fiber.Cookie{
Name: "session_id",
Value: id,
})
return nil
})
h := app.Handler()
// Test GET request without cookie
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Get session cookie
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
token = tokenParts[1]
require.Equal(t, "value="+token, string(ctx.Response.Body()))
// Test GET request with cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
require.Equal(t, "value="+token, string(ctx.Response.Body()))
}
func Test_Session_FromSession(t *testing.T) {
t.Parallel()
app := fiber.New()
sess := FromContext(app.AcquireCtx(&fasthttp.RequestCtx{}))
require.Nil(t, sess)
app.Use(New())
}
func Test_Session_WithConfig(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(c fiber.Ctx) bool {
return c.Get("key") == "value"
},
IdleTimeout: 1 * time.Second,
KeyLookup: "cookie:session_id_test",
KeyGenerator: func() string {
return "test"
},
source: "cookie_test",
sessionName: "session_id_test",
}))
app.Get("/", func(c fiber.Ctx) error {
sess := FromContext(c)
id := sess.ID()
return c.SendString("value=" + id)
})
app.Get("/isFresh", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess.Fresh() {
return c.SendStatus(fiber.StatusOK)
}
return c.SendStatus(fiber.StatusInternalServerError)
})
app.Post("/", func(c fiber.Ctx) error {
sess := FromContext(c)
id := sess.ID()
c.Cookie(&fiber.Cookie{
Name: "session_id_test",
Value: id,
})
return nil
})
h := app.Handler()
// Test GET request without cookie
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Get session cookie
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
token = tokenParts[1]
require.Equal(t, "value="+token, string(ctx.Response.Body()))
// Test GET request with cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("session_id_test", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
require.Equal(t, "value="+token, string(ctx.Response.Body()))
// Test POST request with cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.SetCookie("session_id_test", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test POST request without cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test POST request with wrong key
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test POST request with wrong value
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.SetCookie("session_id_test", "wrong")
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Check idle timeout not expired
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("session_id_test", token)
ctx.Request.SetRequestURI("/isFresh")
h(ctx)
require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())
// Test idle timeout
time.Sleep(1200 * time.Millisecond)
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("session_id_test", token)
ctx.Request.SetRequestURI("/isFresh")
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}
func Test_Session_Next(t *testing.T) {
t.Parallel()
var (
doNext bool
muNext sync.RWMutex
)
app := fiber.New()
app.Use(New(Config{
Next: func(_ fiber.Ctx) bool {
muNext.RLock()
defer muNext.RUnlock()
return doNext
},
}))
app.Get("/", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
id := sess.ID()
return c.SendString("value=" + id)
})
h := app.Handler()
// Test with Next returning false
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Get session cookie
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
token = tokenParts[1]
require.Equal(t, "value="+token, string(ctx.Response.Body()))
// Test with Next returning true
muNext.Lock()
doNext = true
muNext.Unlock()
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())
}
func Test_Session_Middleware_Store(t *testing.T) {
t.Parallel()
app := fiber.New()
handler, sessionStore := NewWithStore()
app.Use(handler)
app.Get("/", func(c fiber.Ctx) error {
sess := FromContext(c)
st := sess.Store()
if st != sessionStore {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
// Test GET request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}

View File

@ -12,95 +12,177 @@ import (
"github.com/valyala/fasthttp"
)
// Session represents a user session.
type Session struct {
ctx fiber.Ctx // fiber context
config *Store // store configuration
data *data // key value data
byteBuffer *bytes.Buffer // byte buffer for the en- and decode
id string // session id
exp time.Duration // expiration of this session
mu sync.RWMutex // Mutex to protect non-data fields
fresh bool // if new session
ctx fiber.Ctx // fiber context
config *Store // store configuration
data *data // key value data
id string // session id
idleTimeout time.Duration // idleTimeout of this session
mu sync.RWMutex // Mutex to protect non-data fields
fresh bool // if new session
}
type absExpirationKeyType int
const (
// sessionIDContextKey is the key used to store the session ID in the context locals.
absExpirationKey absExpirationKeyType = iota
)
// Session pool for reusing byte buffers.
var byteBufferPool = sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
var sessionPool = sync.Pool{
New: func() any {
return new(Session)
return &Session{}
},
}
// acquireSession returns a new Session from the pool.
//
// Returns:
// - *Session: The session object.
//
// Usage:
//
// s := acquireSession()
func acquireSession() *Session {
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
if s.data == nil {
s.data = acquireData()
}
if s.byteBuffer == nil {
s.byteBuffer = new(bytes.Buffer)
}
s.fresh = true
return s
}
// Release releases the session back to the pool.
//
// This function should be called after the session is no longer needed.
// This function is used to reduce the number of allocations and
// to improve the performance of the session store.
//
// The session should not be used after calling this function.
//
// Important: The Release function should only be used when accessing the session directly,
// for example, when you have called func (s *Session) Get(ctx) to get the session.
// It should not be used when using the session with a *Middleware handler in the request
// call stack, as the middleware will still need to access the session.
//
// Usage:
//
// sess := session.Get(ctx)
// defer sess.Release()
func (s *Session) Release() {
if s == nil {
return
}
releaseSession(s)
}
func releaseSession(s *Session) {
s.mu.Lock()
s.id = ""
s.exp = 0
s.idleTimeout = 0
s.ctx = nil
s.config = nil
if s.data != nil {
s.data.Reset()
}
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
s.mu.Unlock()
sessionPool.Put(s)
}
// Fresh is true if the current session is new
// Fresh returns whether the session is new
//
// Returns:
// - bool: True if the session is fresh, otherwise false.
//
// Usage:
//
// isFresh := s.Fresh()
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}
// ID returns the session id
// ID returns the session ID
//
// Returns:
// - string: The session ID.
//
// Usage:
//
// id := s.ID()
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}
// Get will return the value
func (s *Session) Get(key string) any {
// Better safe than sorry
// Get returns the value associated with the given key.
//
// Parameters:
// - key: The key to retrieve.
//
// Returns:
// - any: The value associated with the key.
//
// Usage:
//
// value := s.Get("key")
func (s *Session) Get(key any) any {
if s.data == nil {
return nil
}
return s.data.Get(key)
}
// Set will update or create a new key value
func (s *Session) Set(key string, val any) {
// Better safe than sorry
// Set updates or creates a new key-value pair in the session.
//
// Parameters:
// - key: The key to set.
// - val: The value to set.
//
// Usage:
//
// s.Set("key", "value")
func (s *Session) Set(key, val any) {
if s.data == nil {
return
}
s.data.Set(key, val)
}
// Delete will delete the value
func (s *Session) Delete(key string) {
// Better safe than sorry
// Delete removes the key-value pair from the session.
//
// Parameters:
// - key: The key to delete.
//
// Usage:
//
// s.Delete("key")
func (s *Session) Delete(key any) {
if s.data == nil {
return
}
s.data.Delete(key)
}
// Destroy will delete the session from Storage and expire session cookie
// Destroy deletes the session from storage and expires the session cookie.
//
// Returns:
// - error: An error if the destruction fails.
//
// Usage:
//
// err := s.Destroy()
func (s *Session) Destroy() error {
// Better safe than sorry
if s.data == nil {
return nil
}
@ -121,7 +203,14 @@ func (s *Session) Destroy() error {
return nil
}
// Regenerate generates a new session id and delete the old one from Storage
// Regenerate generates a new session id and deletes the old one from storage.
//
// Returns:
// - error: An error if the regeneration fails.
//
// Usage:
//
// err := s.Regenerate()
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()
@ -137,7 +226,14 @@ func (s *Session) Regenerate() error {
return nil
}
// Reset generates a new session id, deletes the old one from storage, and resets the associated data
// Reset generates a new session id, deletes the old one from storage, and resets the associated data.
//
// Returns:
// - error: An error if the reset fails.
//
// Usage:
//
// err := s.Reset()
func (s *Session) Reset() error {
// Reset local data
if s.data != nil {
@ -147,12 +243,8 @@ func (s *Session) Reset() error {
s.mu.Lock()
defer s.mu.Unlock()
// Reset byte buffer
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
// Reset expiration
s.exp = 0
s.idleTimeout = 0
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
@ -168,75 +260,102 @@ func (s *Session) Reset() error {
return nil
}
// refresh generates a new session, and set session.fresh to be true
// refresh generates a new session, and sets session.fresh to be true.
func (s *Session) refresh() {
s.id = s.config.KeyGenerator()
s.fresh = true
}
// Save will update the storage and client cookie
// Save saves the session data and updates the cookie
//
// sess.Save() will save the session data to the storage and update the
// client cookie, and it will release the session after saving.
// Note: If the session is being used in the handler, calling Save will have
// no effect and the session will automatically be saved when the handler returns.
//
// It's not safe to use the session after calling Save().
// Returns:
// - error: An error if the save operation fails.
//
// Usage:
//
// err := s.Save()
func (s *Session) Save() error {
// Better safe than sorry
if s.ctx == nil {
return s.saveSession()
}
// If the session is being used in the handler, it should not be saved
if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok {
if m.Session == s {
// Session is in use, so we do nothing and return
return nil
}
}
return s.saveSession()
}
// saveSession encodes session data to saves it to storage.
func (s *Session) saveSession() error {
if s.data == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
// Set idleTimeout if not already set
if s.idleTimeout <= 0 {
s.idleTimeout = s.config.IdleTimeout
}
// Update client cookie
s.setSession()
// Convert data to bytes
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
// Encode session data
s.data.RLock()
encodedBytes, err := s.encodeSessionData()
s.data.RUnlock()
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}
s.mu.Unlock()
// Release session
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)
return nil
return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout)
}
// Keys will retrieve all keys in current session
func (s *Session) Keys() []string {
// Keys retrieves all keys in the current session.
//
// Returns:
// - []string: A slice of all keys in the session.
//
// Usage:
//
// keys := s.Keys()
func (s *Session) Keys() []any {
if s.data == nil {
return []string{}
return []any{}
}
return s.data.Keys()
}
// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
// SetIdleTimeout used when saving the session on the next call to `Save()`.
//
// Parameters:
// - idleTimeout: The duration for the idle timeout.
//
// Usage:
//
// s.SetIdleTimeout(time.Hour)
func (s *Session) SetIdleTimeout(idleTimeout time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
s.idleTimeout = idleTimeout
}
func (s *Session) setSession() {
if s.ctx == nil {
return
}
if s.config.source == SourceHeader {
s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id))
s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id))
@ -249,8 +368,8 @@ func (s *Session) setSession() {
// Cookies are also session cookies if they do not specify the Expires or Max-Age attribute.
// refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
if !s.config.CookieSessionOnly {
fcookie.SetMaxAge(int(s.exp.Seconds()))
fcookie.SetExpire(time.Now().Add(s.exp))
fcookie.SetMaxAge(int(s.idleTimeout.Seconds()))
fcookie.SetExpire(time.Now().Add(s.idleTimeout))
}
fcookie.SetSecure(s.config.CookieSecure)
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
@ -269,6 +388,10 @@ func (s *Session) setSession() {
}
func (s *Session) delSession() {
if s.ctx == nil {
return
}
if s.config.source == SourceHeader {
s.ctx.Request().Header.Del(s.config.sessionName)
s.ctx.Response().Header.Del(s.config.sessionName)
@ -299,12 +422,92 @@ func (s *Session) delSession() {
}
}
// decodeSessionData decodes the session data from raw bytes.
// decodeSessionData decodes session data from raw bytes
//
// Parameters:
// - rawData: The raw byte data to decode.
//
// Returns:
// - error: An error if the decoding fails.
//
// Usage:
//
// err := s.decodeSessionData(rawData)
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
defer byteBufferPool.Put(byteBuffer)
defer byteBuffer.Reset()
_, _ = byteBuffer.Write(rawData)
decCache := gob.NewDecoder(byteBuffer)
if err := decCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}
// encodeSessionData encodes session data to raw bytes
//
// Parameters:
// - rawData: The raw byte data to encode.
//
// Returns:
// - error: An error if the encoding fails.
//
// Usage:
//
// err := s.encodeSessionData(rawData)
func (s *Session) encodeSessionData() ([]byte, error) {
byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
defer byteBufferPool.Put(byteBuffer)
defer byteBuffer.Reset()
encCache := gob.NewEncoder(byteBuffer)
if err := encCache.Encode(&s.data.Data); err != nil {
return nil, fmt.Errorf("failed to encode session data: %w", err)
}
// Copy the bytes
// Copy the data in buffer
encodedBytes := make([]byte, byteBuffer.Len())
copy(encodedBytes, byteBuffer.Bytes())
return encodedBytes, nil
}
// absExpiration returns the session absolute expiration time or a zero time if not set.
//
// Returns:
// - time.Time: The session absolute expiration time. Zero time if not set.
//
// Usage:
//
// expiration := s.absExpiration()
func (s *Session) absExpiration() time.Time {
absExpiration, ok := s.Get(absExpirationKey).(time.Time)
if ok {
return absExpiration
}
return time.Time{}
}
// isAbsExpired returns true if the session is expired.
//
// If the session has an absolute expiration time set, this function will return true if the
// current time is after the absolute expiration time.
//
// Returns:
// - bool: True if the session is expired, otherwise false.
func (s *Session) isAbsExpired() bool {
absExpiration := s.absExpiration()
return !absExpiration.IsZero() && time.Now().After(absExpiration)
}
// setAbsoluteExpiration sets the absolute session expiration time.
//
// Parameters:
// - expiration: The session expiration time.
//
// Usage:
//
// s.setExpiration(time.Now().Add(time.Hour))
func (s *Session) setAbsExpiration(absExpiration time.Time) {
s.Set(absExpirationKey, absExpiration)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
@ -17,14 +18,13 @@ func Test_Session(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Get a new session
sess, err := store.Get(ctx)
@ -33,6 +33,7 @@ func Test_Session(t *testing.T) {
token := sess.ID()
require.NoError(t, sess.Save())
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
@ -46,7 +47,7 @@ func Test_Session(t *testing.T) {
// get keys
keys := sess.Keys()
require.Equal(t, []string{}, keys)
require.Equal(t, []any{}, keys)
// get value
name := sess.Get("name")
@ -60,7 +61,7 @@ func Test_Session(t *testing.T) {
require.Equal(t, "john", name)
keys = sess.Keys()
require.Equal(t, []string{"name"}, keys)
require.Equal(t, []any{"name"}, keys)
// delete key
sess.Delete("name")
@ -71,7 +72,7 @@ func Test_Session(t *testing.T) {
// get keys
keys = sess.Keys()
require.Equal(t, []string{}, keys)
require.Equal(t, []any{}, keys)
// get id
id := sess.ID()
@ -81,6 +82,9 @@ func Test_Session(t *testing.T) {
err = sess.Save()
require.NoError(t, err)
// release the session
sess.Release()
// release the context
app.ReleaseCtx(ctx)
// requesting entirely new context to prevent falsy tests
@ -93,6 +97,8 @@ func Test_Session(t *testing.T) {
// this id should be randomly generated as session key was deleted
require.Len(t, sess.ID(), 36)
sess.Release()
// when we use the original session for the second time
// the session be should be same if the session is not expired
app.ReleaseCtx(ctx)
@ -102,6 +108,7 @@ func Test_Session(t *testing.T) {
// request the server with the old session
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
defer sess.Release()
require.NoError(t, err)
require.False(t, sess.Fresh())
require.Equal(t, sess.id, id)
@ -112,7 +119,7 @@ func Test_Session_Types(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
@ -186,6 +193,7 @@ func Test_Session_Types(t *testing.T) {
err = sess.Save()
require.NoError(t, err)
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
@ -277,6 +285,8 @@ func Test_Session_Types(t *testing.T) {
require.True(t, ok)
require.Equal(t, vcomplex128, vcomplex128Result)
sess.Release()
app.ReleaseCtx(ctx)
}
@ -284,7 +294,7 @@ func Test_Session_Types(t *testing.T) {
func Test_Session_Store_Reset(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
@ -304,6 +314,7 @@ func Test_Session_Store_Reset(t *testing.T) {
require.NoError(t, store.Reset())
id := sess.ID()
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
@ -311,11 +322,187 @@ func Test_Session_Store_Reset(t *testing.T) {
// make sure the session is recreated
sess, err = store.Get(ctx)
defer sess.Release()
require.NoError(t, err)
require.True(t, sess.Fresh())
require.Nil(t, sess.Get("hello"))
}
func Test_Session_KeyTypes(t *testing.T) {
t.Parallel()
// session store
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// get session
sess, err := store.Get(ctx)
require.NoError(t, err)
require.True(t, sess.Fresh())
type Person struct {
Name string
}
type unexportedKey int
// register non-default types
store.RegisterType(Person{})
store.RegisterType(unexportedKey(0))
type unregisteredKeyType int
type unregisteredValueType int
// verify unregistered keys types are not allowed
var (
unregisteredKey unregisteredKeyType
unregisteredValue unregisteredValueType
)
sess.Set(unregisteredKey, "test")
err = sess.Save()
require.Error(t, err)
sess.Delete(unregisteredKey)
err = sess.Save()
require.NoError(t, err)
sess.Set("abc", unregisteredValue)
err = sess.Save()
require.Error(t, err)
sess.Delete("abc")
err = sess.Save()
require.NoError(t, err)
require.NoError(t, sess.Reset())
var (
kbool = true
kstring = "str"
kint = 13
kint8 int8 = 13
kint16 int16 = 13
kint32 int32 = 13
kint64 int64 = 13
kuint uint = 13
kuint8 uint8 = 13
kuint16 uint16 = 13
kuint32 uint32 = 13
kuint64 uint64 = 13
kuintptr uintptr = 13
kbyte byte = 'k'
krune = 'k'
kfloat32 float32 = 13
kfloat64 float64 = 13
kcomplex64 complex64 = 13
kcomplex128 complex128 = 13
kuser = Person{Name: "John"}
kunexportedKey = unexportedKey(13)
)
var (
vbool = true
vstring = "str"
vint = 13
vint8 int8 = 13
vint16 int16 = 13
vint32 int32 = 13
vint64 int64 = 13
vuint uint = 13
vuint8 uint8 = 13
vuint16 uint16 = 13
vuint32 uint32 = 13
vuint64 uint64 = 13
vuintptr uintptr = 13
vbyte byte = 'k'
vrune = 'k'
vfloat32 float32 = 13
vfloat64 float64 = 13
vcomplex64 complex64 = 13
vcomplex128 complex128 = 13
vuser = Person{Name: "John"}
vunexportedKey = unexportedKey(13)
)
keys := []any{
kbool,
kstring,
kint,
kint8,
kint16,
kint32,
kint64,
kuint,
kuint8,
kuint16,
kuint32,
kuint64,
kuintptr,
kbyte,
krune,
kfloat32,
kfloat64,
kcomplex64,
kcomplex128,
kuser,
kunexportedKey,
}
values := []any{
vbool,
vstring,
vint,
vint8,
vint16,
vint32,
vint64,
vuint,
vuint8,
vuint16,
vuint32,
vuint64,
vuintptr,
vbyte,
vrune,
vfloat32,
vfloat64,
vcomplex64,
vcomplex128,
vuser,
vunexportedKey,
}
// loop test all key value pairs
for i, key := range keys {
sess.Set(key, values[i])
}
id := sess.ID()
ctx.Request().Header.SetCookie(store.sessionName, id)
// save session
err = sess.Save()
require.NoError(t, err)
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, id)
// get session
sess, err = store.Get(ctx)
require.NoError(t, err)
defer sess.Release()
require.False(t, sess.Fresh())
// loop test all key value pairs
for i, key := range keys {
// get value
result := sess.Get(key)
require.Equal(t, values[i], result)
}
}
// go test -run Test_Session_Save
func Test_Session_Save(t *testing.T) {
t.Parallel()
@ -323,7 +510,7 @@ func Test_Session_Save(t *testing.T) {
t.Run("save to cookie", func(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
@ -338,12 +525,13 @@ func Test_Session_Save(t *testing.T) {
// save session
err = sess.Save()
require.NoError(t, err)
sess.Release()
})
t.Run("save to header", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
store := NewStore(Config{
KeyLookup: "header:session_id",
})
// fiber instance
@ -363,10 +551,11 @@ func Test_Session_Save(t *testing.T) {
require.NoError(t, err)
require.Equal(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
require.Equal(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
sess.Release()
})
}
func Test_Session_Save_Expiration(t *testing.T) {
func Test_Session_Save_IdleTimeout(t *testing.T) {
t.Parallel()
t.Run("save to cookie", func(t *testing.T) {
@ -374,7 +563,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
const sessionDuration = 5 * time.Second
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
@ -391,12 +580,13 @@ func Test_Session_Save_Expiration(t *testing.T) {
token := sess.ID()
// expire this session in 5 seconds
sess.SetExpiry(sessionDuration)
sess.SetIdleTimeout(sessionDuration)
// save session
err = sess.Save()
require.NoError(t, err)
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
@ -409,6 +599,8 @@ func Test_Session_Save_Expiration(t *testing.T) {
// just to make sure the session has been expired
time.Sleep(sessionDuration + (10 * time.Millisecond))
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
@ -416,12 +608,97 @@ func Test_Session_Save_Expiration(t *testing.T) {
// here you should get a new session
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
defer sess.Release()
require.NoError(t, err)
require.Nil(t, sess.Get("name"))
require.NotEqual(t, sess.ID(), token)
})
}
func Test_Session_Save_AbsoluteTimeout(t *testing.T) {
t.Parallel()
t.Run("save to cookie", func(t *testing.T) {
t.Parallel()
const absoluteTimeout = 1 * time.Second
// session store
store := NewStore(Config{
IdleTimeout: absoluteTimeout,
AbsoluteTimeout: absoluteTimeout,
})
// force change to IdleTimeout
store.Config.IdleTimeout = 10 * time.Second
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
require.NoError(t, err)
// set value
sess.Set("name", "john")
token := sess.ID()
// save session
err = sess.Save()
require.NoError(t, err)
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// here you need to get the old session yet
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
require.NoError(t, err)
require.Equal(t, "john", sess.Get("name"))
// just to make sure the session has been expired
time.Sleep(absoluteTimeout + (100 * time.Millisecond))
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// here you should get a new session
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
require.NoError(t, err)
require.Nil(t, sess.Get("name"))
require.NotEqual(t, sess.ID(), token)
require.True(t, sess.Fresh())
require.IsType(t, time.Time{}, sess.Get(absExpirationKey))
token = sess.ID()
sess.Set("name", "john")
// save session
err = sess.Save()
require.NoError(t, err)
sess.Release()
app.ReleaseCtx(ctx)
// just to make sure the session has been expired
time.Sleep(absoluteTimeout + (100 * time.Millisecond))
// try to get expired session by id
sess, err = store.GetByID(token)
require.Error(t, err)
require.ErrorIs(t, err, ErrSessionIDNotFoundInStore)
require.Nil(t, sess)
})
}
// go test -run Test_Session_Destroy
func Test_Session_Destroy(t *testing.T) {
t.Parallel()
@ -429,7 +706,7 @@ func Test_Session_Destroy(t *testing.T) {
t.Run("destroy from cookie", func(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
@ -438,6 +715,7 @@ func Test_Session_Destroy(t *testing.T) {
// get session
sess, err := store.Get(ctx)
defer sess.Release()
require.NoError(t, err)
sess.Set("name", "fenny")
@ -449,7 +727,7 @@ func Test_Session_Destroy(t *testing.T) {
t.Run("destroy from header", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
store := NewStore(Config{
KeyLookup: "header:session_id",
})
// fiber instance
@ -467,6 +745,7 @@ func Test_Session_Destroy(t *testing.T) {
id := sess.ID()
require.NoError(t, sess.Save())
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
@ -475,6 +754,7 @@ func Test_Session_Destroy(t *testing.T) {
ctx.Request().Header.Set(store.sessionName, id)
sess, err = store.Get(ctx)
require.NoError(t, err)
defer sess.Release()
err = sess.Destroy()
require.NoError(t, err)
@ -487,19 +767,19 @@ func Test_Session_Destroy(t *testing.T) {
func Test_Session_Custom_Config(t *testing.T) {
t.Parallel()
store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }})
require.Equal(t, time.Hour, store.Expiration)
store := NewStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }})
require.Equal(t, time.Hour, store.IdleTimeout)
require.Equal(t, "very random", store.KeyGenerator())
store = New(Config{Expiration: 0})
require.Equal(t, ConfigDefault.Expiration, store.Expiration)
store = NewStore(Config{IdleTimeout: 0})
require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout)
}
// go test -run Test_Session_Cookie
func Test_Session_Cookie(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber instance
app := fiber.New()
// fiber context
@ -511,15 +791,19 @@ func Test_Session_Cookie(t *testing.T) {
require.NoError(t, err)
require.NoError(t, sess.Save())
sess.Release()
// cookie should be set on Save ( even if empty data )
require.Len(t, ctx.Response().Header.PeekCookie(store.sessionName), 84)
cookie := ctx.Response().Header.PeekCookie(store.sessionName)
require.NotNil(t, cookie)
require.Regexp(t, `^session_id=[a-f0-9\-]{36}; max-age=\d+; path=/; SameSite=Lax$`, string(cookie))
}
// go test -run Test_Session_Cookie_In_Response
// Regression: https://github.com/gofiber/fiber/pull/1191
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
t.Parallel()
store := New()
store := NewStore()
app := fiber.New()
// fiber context
@ -534,8 +818,11 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
id := sess.ID()
require.NoError(t, sess.Save())
sess.Release()
sess, err = store.Get(ctx)
require.NoError(t, err)
defer sess.Release()
sess.Set("name", "john")
require.True(t, sess.Fresh())
require.Equal(t, id, sess.ID()) // session id should be the same
@ -548,7 +835,7 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
// Regression: https://github.com/gofiber/fiber/issues/1365
func Test_Session_Deletes_Single_Key(t *testing.T) {
t.Parallel()
store := New()
store := NewStore()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
@ -559,6 +846,7 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
sess.Set("id", "1")
require.NoError(t, sess.Save())
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
@ -568,11 +856,13 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
sess.Delete("id")
require.NoError(t, sess.Save())
sess.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
defer sess.Release()
require.NoError(t, err)
require.False(t, sess.Fresh())
require.Nil(t, sess.Get("id"))
@ -587,7 +877,7 @@ func Test_Session_Reset(t *testing.T) {
app := fiber.New()
// session store
store := New()
store := NewStore()
t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
t.Parallel()
@ -609,6 +899,7 @@ func Test_Session_Reset(t *testing.T) {
err = freshSession.Save()
require.NoError(t, err)
freshSession.Release()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
@ -630,7 +921,7 @@ func Test_Session_Reset(t *testing.T) {
// Check that the session data has been reset
keys := acquiredSession.Keys()
require.Equal(t, []string{}, keys)
require.Equal(t, []any{}, keys)
// Set a new value for 'name' and check that it's updated
acquiredSession.Set("name", "john")
@ -641,6 +932,8 @@ func Test_Session_Reset(t *testing.T) {
err = acquiredSession.Save()
require.NoError(t, err)
acquiredSession.Release()
// Check that the session id is not in the header or cookie anymore
require.Equal(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
require.Equal(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
@ -658,7 +951,7 @@ func Test_Session_Regenerate(t *testing.T) {
t.Run("set fresh to be true when regenerating a session", func(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// a random session uuid
originalSessionUUIDString := ""
// fiber context
@ -674,6 +967,8 @@ func Test_Session_Regenerate(t *testing.T) {
err = freshSession.Save()
require.NoError(t, err)
freshSession.Release()
// release the context
app.ReleaseCtx(ctx)
@ -686,6 +981,7 @@ func Test_Session_Regenerate(t *testing.T) {
// as the session is in the storage, session.fresh should be false
acquiredSession, err := store.Get(ctx)
require.NoError(t, err)
defer acquiredSession.Release()
require.False(t, acquiredSession.Fresh())
err = acquiredSession.Regenerate()
@ -704,7 +1000,7 @@ func Test_Session_Regenerate(t *testing.T) {
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
func Benchmark_Session(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
app, store := fiber.New(), NewStore()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
@ -715,12 +1011,14 @@ func Benchmark_Session(b *testing.B) {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
store := NewStore(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
@ -733,6 +1031,8 @@ func Benchmark_Session(b *testing.B) {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
}
})
}
@ -740,7 +1040,7 @@ func Benchmark_Session(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
func Benchmark_Session_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
app, store := fiber.New(), NewStore()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
@ -751,6 +1051,9 @@ func Benchmark_Session_Parallel(b *testing.B) {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
app.ReleaseCtx(c)
}
})
@ -758,7 +1061,7 @@ func Benchmark_Session_Parallel(b *testing.B) {
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
store := NewStore(Config{
Storage: memory.New(),
})
b.ReportAllocs()
@ -771,6 +1074,9 @@ func Benchmark_Session_Parallel(b *testing.B) {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
app.ReleaseCtx(c)
}
})
@ -780,7 +1086,7 @@ func Benchmark_Session_Parallel(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
func Benchmark_Session_Asserted(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
app, store := fiber.New(), NewStore()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
@ -793,12 +1099,13 @@ func Benchmark_Session_Asserted(b *testing.B) {
sess.Set("john", "doe")
err = sess.Save()
require.NoError(b, err)
sess.Release()
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
store := NewStore(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
@ -813,6 +1120,7 @@ func Benchmark_Session_Asserted(b *testing.B) {
sess.Set("john", "doe")
err = sess.Save()
require.NoError(b, err)
sess.Release()
}
})
}
@ -820,7 +1128,7 @@ func Benchmark_Session_Asserted(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
app, store := fiber.New(), NewStore()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
@ -832,6 +1140,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
require.NoError(b, err)
sess.Set("john", "doe")
require.NoError(b, sess.Save())
sess.Release()
app.ReleaseCtx(c)
}
})
@ -839,7 +1148,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
store := NewStore(Config{
Storage: memory.New(),
})
b.ReportAllocs()
@ -853,6 +1162,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
require.NoError(b, err)
sess.Set("john", "doe")
require.NoError(b, sess.Save())
sess.Release()
app.ReleaseCtx(c)
}
})
@ -863,7 +1173,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()
store := NewStore()
var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
@ -877,7 +1187,7 @@ func Test_Session_Concurrency(t *testing.T) {
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.Get(localCtx)
sess, err := store.getSession(localCtx)
if err != nil {
errChan <- err
return
@ -901,6 +1211,9 @@ func Test_Session_Concurrency(t *testing.T) {
return
}
// release the session
sess.Release()
// Release the context
app.ReleaseCtx(localCtx)
@ -917,6 +1230,7 @@ func Test_Session_Concurrency(t *testing.T) {
errChan <- err
return
}
defer sess.Release()
// Get the value
name := sess.Get("name")
@ -963,3 +1277,42 @@ func Test_Session_Concurrency(t *testing.T) {
require.NoError(t, err)
}
}
func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) {
// Initialize a new store with default config
store := NewStore()
// Create a new Fiber app
app := fiber.New()
// Generate a fake session ID
sessionID := uuid.New().String()
// Store invalid session data to simulate decode error
err := store.Storage.Set(sessionID, []byte("invalid data"), 0)
require.NoError(t, err, "Failed to set invalid session data")
// Create a new request context
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
// Set the session ID in cookies
c.Request().Header.SetCookie(store.sessionName, sessionID)
// Attempt to get the session
_, err = store.Get(c)
require.Error(t, err, "Expected error due to invalid session data, but got nil")
// Check that the error message is as expected
require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message")
// Check that the error is as expected
require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")
// Attempt to get the session by ID
_, err = store.GetByID(sessionID)
require.Error(t, err, "Expected error due to invalid session data, but got nil")
// Check that the error message is as expected
require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")
}

View File

@ -4,14 +4,20 @@ import (
"encoding/gob"
"errors"
"fmt"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
var (
ErrEmptySessionID = errors.New("session ID cannot be empty")
ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware")
ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store")
)
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
@ -26,7 +32,17 @@ type Store struct {
}
// New creates a new session store with the provided configuration.
func New(config ...Config) *Store {
//
// Parameters:
// - config: Variadic parameter to override default config.
//
// Returns:
// - *Store: The session store.
//
// Usage:
//
// store := session.New()
func NewStore(config ...Config) *Store {
// Set default config
cfg := configDefault(config...)
@ -34,18 +50,75 @@ func New(config ...Config) *Store {
cfg.Storage = memory.New()
}
return &Store{
store := &Store{
Config: cfg,
}
if cfg.AbsoluteTimeout > 0 {
store.RegisterType(absExpirationKey)
store.RegisterType(time.Time{})
}
return store
}
// RegisterType registers a custom type for encoding/decoding into any storage provider.
//
// Parameters:
// - i: The custom type to register.
//
// Usage:
//
// store.RegisterType(MyCustomType{})
func (*Store) RegisterType(i any) {
gob.Register(i)
}
// Get retrieves or creates a session for the given context.
// Get will get/create a session.
//
// This function will return an ErrSessionAlreadyLoadedByMiddleware if
// the session is already loaded by the middleware.
//
// Parameters:
// - c: The Fiber context.
//
// Returns:
// - *Session: The session object.
// - error: An error if the session retrieval fails or if the session is already loaded by the middleware.
//
// Usage:
//
// sess, err := store.Get(c)
// if err != nil {
// // handle error
// }
func (s *Store) Get(c fiber.Ctx) (*Session, error) {
// If session is already loaded in the context,
// it should not be loaded again
_, ok := c.Locals(middlewareContextKey).(*Middleware)
if ok {
return nil, ErrSessionAlreadyLoadedByMiddleware
}
return s.getSession(c)
}
// getSession retrieves a session based on the context.
//
// Parameters:
// - c: The Fiber context.
//
// Returns:
// - *Session: The session object.
// - error: An error if the session retrieval fails.
//
// Usage:
//
// sess, err := store.getSession(c)
// if err != nil {
// // handle error
// }
func (s *Store) getSession(c fiber.Ctx) (*Session, error) {
var rawData []byte
var err error
@ -79,7 +152,6 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) {
sess := acquireSession()
sess.mu.Lock()
defer sess.mu.Unlock()
sess.ctx = c
sess.config = s
@ -89,16 +161,40 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) {
// Decode session data if found
if rawData != nil {
sess.data.Lock()
defer sess.data.Unlock()
if err := sess.decodeSessionData(rawData); err != nil {
err := sess.decodeSessionData(rawData)
sess.data.Unlock()
if err != nil {
sess.mu.Unlock()
sess.Release()
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
}
sess.mu.Unlock()
if fresh && s.AbsoluteTimeout > 0 {
sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout))
} else if sess.isAbsExpired() {
if err := sess.Reset(); err != nil {
return nil, fmt.Errorf("failed to reset session: %w", err)
}
sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout))
}
return sess, nil
}
// getSessionID returns the session ID from cookies, headers, or query string.
//
// Parameters:
// - c: The Fiber context.
//
// Returns:
// - string: The session ID.
//
// Usage:
//
// id := store.getSessionID(c)
func (s *Store) getSessionID(c fiber.Ctx) string {
id := c.Cookies(s.sessionName)
if len(id) > 0 {
@ -123,14 +219,113 @@ func (s *Store) getSessionID(c fiber.Ctx) string {
}
// Reset deletes all sessions from the storage.
//
// Returns:
// - error: An error if the reset operation fails.
//
// Usage:
//
// err := store.Reset()
// if err != nil {
// // handle error
// }
func (s *Store) Reset() error {
return s.Storage.Reset()
}
// Delete deletes a session by its ID.
//
// Parameters:
// - id: The unique identifier of the session.
//
// Returns:
// - error: An error if the deletion fails or if the session ID is empty.
//
// Usage:
//
// err := store.Delete(id)
// if err != nil {
// // handle error
// }
func (s *Store) Delete(id string) error {
if id == "" {
return ErrEmptySessionID
}
return s.Storage.Delete(id)
}
// GetByID retrieves a session by its ID from the storage.
// If the session is not found, it returns nil and an error.
//
// Unlike session middleware methods, this function does not automatically:
//
// - Load the session into the request context.
//
// - Save the session data to the storage or update the client cookie.
//
// Important Notes:
//
// - The session object returned by GetByID does not have a context associated with it.
//
// - When using this method alongside session middleware, there is a potential for collisions,
// so be mindful of interactions between manually retrieved sessions and middleware-managed sessions.
//
// - If you modify a session returned by GetByID, you must call session.Save() to persist the changes.
//
// - When you are done with the session, you should call session.Release() to release the session back to the pool.
//
// Parameters:
// - id: The unique identifier of the session.
//
// Returns:
// - *Session: The session object if found, otherwise nil.
// - error: An error if the session retrieval fails or if the session ID is empty.
//
// Usage:
//
// sess, err := store.GetByID(id)
// if err != nil {
// // handle error
// }
func (s *Store) GetByID(id string) (*Session, error) {
if id == "" {
return nil, ErrEmptySessionID
}
rawData, err := s.Storage.Get(id)
if err != nil {
return nil, err
}
if rawData == nil {
return nil, ErrSessionIDNotFoundInStore
}
sess := acquireSession()
sess.mu.Lock()
sess.config = s
sess.id = id
sess.fresh = false
sess.data.Lock()
decodeErr := sess.decodeSessionData(rawData)
sess.data.Unlock()
sess.mu.Unlock()
if decodeErr != nil {
sess.Release()
return nil, fmt.Errorf("failed to decode session data: %w", decodeErr)
}
if s.AbsoluteTimeout > 0 {
if sess.isAbsExpired() {
if err := sess.Destroy(); err != nil {
sess.Release()
log.Errorf("failed to destroy session: %v", err)
}
return nil, ErrSessionIDNotFoundInStore
}
}
return sess, nil
}

View File

@ -20,9 +20,10 @@ func Test_Store_getSessionID(t *testing.T) {
t.Run("from cookie", func(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, expectedID)
@ -33,11 +34,12 @@ func Test_Store_getSessionID(t *testing.T) {
t.Run("from header", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
store := NewStore(Config{
KeyLookup: "header:session_id",
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set header
ctx.Request().Header.Set(store.sessionName, expectedID)
@ -48,11 +50,12 @@ func Test_Store_getSessionID(t *testing.T) {
t.Run("from url query", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
store := NewStore(Config{
KeyLookup: "query:session_id",
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set url parameter
ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.sessionName, expectedID))
@ -73,9 +76,10 @@ func Test_Store_Get(t *testing.T) {
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
t.Parallel()
// session store
store := New()
store := NewStore()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, unexpectedID)
@ -93,10 +97,11 @@ func Test_Store_DeleteSession(t *testing.T) {
// fiber instance
app := fiber.New()
// session store
store := New()
store := NewStore()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Create a new session
session, err := store.Get(ctx)
@ -116,3 +121,105 @@ func Test_Store_DeleteSession(t *testing.T) {
// The session ID should be different now, because the old session was deleted
require.NotEqual(t, sessionID, session.ID())
}
func TestStore_Get_SessionAlreadyLoaded(t *testing.T) {
// Create a new Fiber app
app := fiber.New()
// Create a new context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Mock middleware and set it in the context
middleware := &Middleware{}
ctx.Locals(middlewareContextKey, middleware)
// Create a new store
store := &Store{}
// Call the Get method
sess, err := store.Get(ctx)
// Assert that the error is ErrSessionAlreadyLoadedByMiddleware
require.Nil(t, sess)
require.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err)
}
func TestStore_Delete(t *testing.T) {
// Create a new store
store := NewStore()
t.Run("delete with empty session ID", func(t *testing.T) {
err := store.Delete("")
require.Error(t, err)
require.Equal(t, ErrEmptySessionID, err)
})
t.Run("delete non-existing session", func(t *testing.T) {
err := store.Delete("non-existing-session-id")
require.NoError(t, err)
})
}
func Test_Store_GetByID(t *testing.T) {
t.Parallel()
// Create a new store
store := NewStore()
t.Run("empty session ID", func(t *testing.T) {
t.Parallel()
sess, err := store.GetByID("")
require.Error(t, err)
require.Nil(t, sess)
require.Equal(t, ErrEmptySessionID, err)
})
t.Run("non-existent session ID", func(t *testing.T) {
t.Parallel()
sess, err := store.GetByID("non-existent-session-id")
require.Error(t, err)
require.Nil(t, sess)
require.Equal(t, ErrSessionIDNotFoundInStore, err)
})
t.Run("valid session ID", func(t *testing.T) {
t.Parallel()
app := fiber.New()
// Create a new session
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
session, err := store.Get(ctx)
defer session.Release()
defer app.ReleaseCtx(ctx)
require.NoError(t, err)
// Save the session ID
sessionID := session.ID()
// Save the session
err = session.Save()
require.NoError(t, err)
// Retrieve the session by ID
retrievedSession, err := store.GetByID(sessionID)
require.NoError(t, err)
require.NotNil(t, retrievedSession)
require.Equal(t, sessionID, retrievedSession.ID())
// Call Save on the retrieved session
retrievedSession.Set("key", "value")
err = retrievedSession.Save()
require.NoError(t, err)
// Call Other Session methods
require.Equal(t, "value", retrievedSession.Get("key"))
require.False(t, retrievedSession.Fresh())
require.NoError(t, retrievedSession.Reset())
require.NoError(t, retrievedSession.Destroy())
require.IsType(t, []any{}, retrievedSession.Keys())
require.NoError(t, retrievedSession.Regenerate())
require.NotPanics(t, func() {
retrievedSession.Release()
})
})
}