♻️ v3: fix!: ContextKey collisions (#2781)

* fix: ContextKey collisions

* fix(logger): lint error

* docs(csrf): fix potential range error in example
pull/2787/head
Jason McNeil 2024-01-04 04:44:45 -04:00 committed by GitHub
parent f37238e494
commit 2954e3bbae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 242 additions and 167 deletions

6
ctx.go
View File

@ -33,8 +33,12 @@ const (
// maxParams defines the maximum number of parameters per route.
const maxParams = 30
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
const userContextKey = "__local_user_context__"
const userContextKey contextKey = 0 // __local_user_context__
type DefaultCtx struct {
app *App // Reference to *App

View File

@ -895,13 +895,23 @@ func (c *Ctx) Locals(key interface{}, value ...interface{}) interface{}
```
```go title="Example"
// key is an unexported type for keys defined in this package.
// This prevents collisions with keys defined in other packages.
type key int
// userKey is the key for user.User values in Contexts. It is
// unexported; clients use user.NewContext and user.FromContext
// instead of using this key directly.
var userKey key
app.Use(func(c *fiber.Ctx) error {
c.Locals("user", "admin")
c.Locals(userKey, "admin")
return c.Next()
})
app.Get("/admin", func(c *fiber.Ctx) error {
if c.Locals("user") == "admin" {
if c.Locals(userKey) == "admin" {
return c.Status(fiber.StatusOK).SendString("Welcome, admin!")
}
return c.SendStatus(fiber.StatusForbidden)

View File

@ -10,6 +10,8 @@ Basic Authentication middleware for [Fiber](https://github.com/gofiber/fiber) th
```go
func New(config Config) fiber.Handler
func UsernameFromContext(c *fiber.Ctx) string
func PasswordFromContext(c *fiber.Ctx) string
```
## Examples
@ -53,11 +55,20 @@ app.Use(basicauth.New(basicauth.Config{
Unauthorized: func(c *fiber.Ctx) error {
return c.SendFile("./unauthorized.html")
},
ContextUsername: "_user",
ContextPassword: "_pass",
}))
```
Getting the username and password
```go
func handler(c *fiber.Ctx) error {
username := basicauth.UsernameFromContext(c)
password := basicauth.PasswordFromContext(c)
log.Printf("Username: %s Password: %s", username, password)
return c.SendString("Hello, " + username)
}
```
## Config
| Property | Type | Description | Default |
@ -67,8 +78,6 @@ app.Use(basicauth.New(basicauth.Config{
| Realm | `string` | Realm is a string to define the realm attribute of BasicAuth. The realm identifies the system to authenticate against and can be used by clients to save credentials. | `"Restricted"` |
| Authorizer | `func(string, string) bool` | Authorizer defines a function to check the credentials. It will be called with a username and password and is expected to return true or false to indicate approval. | `nil` |
| Unauthorized | `fiber.Handler` | Unauthorized defines the response body for unauthorized responses. | `nil` |
| ContextUsername | `string` | ContextUsername is the key to store the username in Locals. | `"username"` |
| ContextPassword | `string` | ContextPassword is the key to store the password in Locals. | `"password"` |
## Default Config
@ -79,7 +88,5 @@ var ConfigDefault = Config{
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
ContextUsername: "username",
ContextPassword: "password",
}
```

View File

@ -10,7 +10,7 @@ This middleware can be used with or without a user session and offers two token
## Token Generation
CSRF tokens are generated on 'safe' requests and when the existing token has expired or hasn't been set yet. If `SingleUseToken` is `true`, a new token is generated after each use. Retrieve the CSRF token using `c.Locals(contextKey)`, where `contextKey` is defined in the configuration.
CSRF tokens are generated on 'safe' requests and when the existing token has expired or hasn't been set yet. If `SingleUseToken` is `true`, a new token is generated after each use. Retrieve the CSRF token using `csrf.TokenFromContext(c)`.
## Security Considerations
@ -82,7 +82,8 @@ Using `SingleUseToken` comes with usability trade-offs and is not enabled by def
When the authorization status changes, the CSRF token MUST be deleted, and a new one generated. This can be done by calling `handler.DeleteToken(c)`.
```go
if handler, ok := app.AcquireCtx(ctx).Locals(ConfigDefault.HandlerContextKey).(*CSRFHandler); ok {
handler := csrf.HandlerFromContext(ctx)
if handler != nil {
if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
// handle error
}
@ -101,6 +102,10 @@ It's important to note that the token is sent as a header on every request. If y
```go
func New(config ...Config) fiber.Handler
func TokenFromContext(c *fiber.Ctx) string
func HandlerFromContext(c *fiber.Ctx) *Handler
func (h *Handler) DeleteToken(c *fiber.Ctx) error
```
## Examples
@ -135,6 +140,36 @@ app.Use(csrf.New(csrf.Config{
KeyLookup will be ignored if Extractor is explicitly set.
:::
Getting the CSRF token in a handler:
```go
```go
func handler(c *fiber.Ctx) error {
handler := csrf.HandlerFromContext(c)
token := csrf.TokenFromContext(c)
if handler == nil {
panic("csrf middleware handler not registered")
}
cfg := handler.Config
if cfg == nil {
panic("csrf middleware handler has no config")
}
if !strings.Contains(cfg.KeyLookup, ":") {
panic("invalid KeyLookup format")
}
formKey := strings.Split(cfg.KeyLookup, ":")[1]
tmpl := fmt.Sprintf(`<form action="/post" method="POST">
<input type="hidden" name="%s" value="%s">
<input type="text" name="message">
<input type="submit" value="Submit">
</form>`, formKey, token)
c.Set("Content-Type", "text/html")
return c.SendString(tmpl)
}
```
## Config
| Property | Type | Description | Default |
@ -152,15 +187,10 @@ KeyLookup will be ignored if Extractor is explicitly set.
| SingleUseToken | `bool` | SingleUseToken indicates if the CSRF token be destroyed and a new one generated on each use. (See TokenLifecycle) | false |
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "fiber.csrf.token" |
| ContextKey | `string` | Context key to store the generated CSRF token into the context. If left empty, the token will not be stored in the context. | "" |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| KeyGenerator | `func() string` | KeyGenerator creates a new CSRF token. | utils.UUID |
| CookieExpires | `time.Duration` (Deprecated) | Deprecated: Please use Expiration. | 0 |
| Cookie | `*fiber.Cookie` (Deprecated) | Deprecated: Please use Cookie* related fields. | `nil` |
| TokenLookup | `string` (Deprecated) | Deprecated: Please use KeyLookup. | "" |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler is executed when an error is returned from fiber.Handler. | DefaultErrorHandler |
| Extractor | `func(*fiber.Ctx) (string, error)` | Extractor returns the CSRF token. If set, this will be used in place of an Extractor based on KeyLookup. | Extractor based on KeyLookup |
| HandlerContextKey | `string` | HandlerContextKey is used to store the CSRF Handler into context. | "fiber.csrf.handler" |
### Default Config
@ -173,8 +203,7 @@ var ConfigDefault = Config{
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
SessionKey: "fiber.csrf.token",
HandlerContextKey: "fiber.csrf.handler",
SessionKey: "csrfToken",
}
```
@ -194,8 +223,7 @@ var ConfigDefault = Config{
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
Session: session.Store,
SessionKey: "fiber.csrf.token",
HandlerContextKey: "fiber.csrf.handler",
SessionKey: "csrfToken",
}
```

View File

@ -21,6 +21,7 @@ Safe HTTP methods — `GET`, `HEAD`, `OPTIONS` and `TRACE` — should not modify
```go
func New(config ...Config) fiber.Handler
func IsEarlyData(c fiber.Ctx) bool
```
## Examples

View File

@ -12,6 +12,8 @@ Refer to https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-ke
```go
func New(config ...Config) fiber.Handler
func IsFromCache(c fiber.Ctx) bool
func WasPutToCache(c fiber.Ctx) bool
```
## Examples

View File

@ -10,6 +10,7 @@ Key auth middleware provides a key based authentication.
```go
func New(config ...Config) fiber.Handler
func TokenFromContext(c fiber.Ctx) string
```
## Examples
@ -213,15 +214,14 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000
## Config
| Property | Type | Description | Default |
|:---------------|:-----------------------------------------|:-----------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| Property | Type | Description | Default |
|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(*fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| ContextKey | `string` | Context key to store the bearer token from the token into context. | "token" |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(*fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
## Default Config
@ -238,6 +238,5 @@ var ConfigDefault = Config{
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
ContextKey: "token",
}
```

View File

@ -10,6 +10,7 @@ RequestID middleware for [Fiber](https://github.com/gofiber/fiber) that adds an
```go
func New(config ...Config) fiber.Handler
func FromContext(c *fiber.Ctx) string
```
## Examples
@ -38,6 +39,16 @@ app.Use(requestid.New(requestid.Config{
}))
```
Getting the request ID
```go
func handler(c *fiber.Ctx) error {
id := requestid.FromContext(c)
log.Printf("Request ID: %s", id)
return c.SendString("Hello, World!")
}
```
## Config
| Property | Type | Description | Default |
@ -45,7 +56,6 @@ app.Use(requestid.New(requestid.Config{
| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| Header | `string` | Header is the header key where to get/set the unique request ID. | "X-Request-ID" |
| Generator | `func() string` | Generator defines a function to generate the unique identifier. | utils.UUID |
| ContextKey | `interface{}` | ContextKey defines the key used when storing the request ID in the locals for a specific request. | "requestid" |
## Default Config
The default config uses a fast UUID generator which will expose the number of
@ -57,6 +67,5 @@ var ConfigDefault = Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: utils.UUID,
ContextKey: "requestid",
}
```

View File

@ -8,6 +8,16 @@ import (
"github.com/gofiber/utils/v2"
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// The keys for the values in context
const (
usernameKey contextKey = iota
passwordKey
)
// New creates a new middleware handler
func New(config Config) fiber.Handler {
// Set default config
@ -49,8 +59,8 @@ func New(config Config) fiber.Handler {
password := creds[index+1:]
if cfg.Authorizer(username, password) {
c.Locals(cfg.ContextUsername, username)
c.Locals(cfg.ContextPassword, password)
c.Locals(usernameKey, username)
c.Locals(passwordKey, password)
return c.Next()
}
@ -58,3 +68,23 @@ func New(config Config) fiber.Handler {
return cfg.Unauthorized(c)
}
}
// UsernameFromContext returns the username found in the context
// returns an empty string if the username does not exist
func UsernameFromContext(c fiber.Ctx) string {
username, ok := c.Locals(usernameKey).(string)
if !ok {
return ""
}
return username
}
// PasswordFromContext returns the password found in the context
// returns an empty string if the password does not exist
func PasswordFromContext(c fiber.Ctx) string {
password, ok := c.Locals(passwordKey).(string)
if !ok {
return ""
}
return password
}

View File

@ -39,8 +39,8 @@ func Test_Middleware_BasicAuth(t *testing.T) {
}))
app.Get("/testauth", func(c fiber.Ctx) error {
username := c.Locals("username").(string) //nolint:errcheck, forcetypeassert // not needed
password := c.Locals("password").(string) //nolint:errcheck, forcetypeassert // not needed
username := UsernameFromContext(c)
password := PasswordFromContext(c)
return c.SendString(username + password)
})

View File

@ -40,27 +40,15 @@ type Config struct {
//
// Optional. Default: nil
Unauthorized fiber.Handler
// ContextUser is the key to store the username in Locals
//
// Optional. Default: "username"
ContextUsername string
// ContextPass is the key to store the password in Locals
//
// Optional. Default: "password"
ContextPassword string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
ContextUsername: "username",
ContextPassword: "password",
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
}
// Helper function to set default values
@ -95,11 +83,5 @@ func configDefault(config ...Config) Config {
return c.SendStatus(fiber.StatusUnauthorized)
}
}
if cfg.ContextUsername == "" {
cfg.ContextUsername = ConfigDefault.ContextUsername
}
if cfg.ContextPassword == "" {
cfg.ContextPassword = ConfigDefault.ContextPassword
}
return cfg
}

View File

@ -86,29 +86,14 @@ type Config struct {
// SessionKey is the key used to store the token in the session
//
// Default: "fiber.csrf.token"
// Default: "csrfToken"
SessionKey string
// Context key to store generated CSRF token into context.
// If left empty, token will not be stored in context.
//
// Optional. Default: ""
ContextKey string
// KeyGenerator creates a new CSRF token
//
// Optional. Default: utils.UUID
KeyGenerator func() string
// Deprecated: Please use Expiration
CookieExpires time.Duration
// Deprecated: Please use Cookie* related fields
Cookie *fiber.Cookie
// Deprecated: Please use KeyLookup
TokenLookup string
// ErrorHandler is executed when an error is returned from fiber.Handler.
//
// Optional. Default: DefaultErrorHandler
@ -120,26 +105,20 @@ type Config struct {
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c fiber.Ctx) (string, error)
// HandlerContextKey is used to store the CSRF Handler into context
//
// Default: "fiber.csrf.handler"
HandlerContextKey string
}
const HeaderName = "X-Csrf-Token"
// ConfigDefault is the default config
var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
SessionKey: "fiber.csrf.token",
HandlerContextKey: "fiber.csrf.handler",
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
SessionKey: "csrfToken",
}
// default ErrorHandler that process return error from fiber.Handler
@ -158,31 +137,6 @@ func configDefault(config ...Config) Config {
cfg := config[0]
// Set default values
if cfg.TokenLookup != "" {
log.Warn("[CSRF] TokenLookup is deprecated, please use KeyLookup")
cfg.KeyLookup = cfg.TokenLookup
}
if int(cfg.CookieExpires.Seconds()) > 0 {
log.Warn("[CSRF] CookieExpires is deprecated, please use Expiration")
cfg.Expiration = cfg.CookieExpires
}
if cfg.Cookie != nil {
log.Warn("[CSRF] Cookie is deprecated, please use Cookie* related fields")
if cfg.Cookie.Name != "" {
cfg.CookieName = cfg.Cookie.Name
}
if cfg.Cookie.Domain != "" {
cfg.CookieDomain = cfg.Cookie.Domain
}
if cfg.Cookie.Path != "" {
cfg.CookiePath = cfg.Cookie.Path
}
cfg.CookieSecure = cfg.Cookie.Secure
cfg.CookieHTTPOnly = cfg.Cookie.HTTPOnly
if cfg.Cookie.SameSite != "" {
cfg.CookieSameSite = cfg.Cookie.SameSite
}
}
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
}
@ -204,9 +158,6 @@ func configDefault(config ...Config) Config {
if cfg.SessionKey == "" {
cfg.SessionKey = ConfigDefault.SessionKey
}
if cfg.HandlerContextKey == "" {
cfg.HandlerContextKey = ConfigDefault.HandlerContextKey
}
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")

View File

@ -23,6 +23,16 @@ type CSRFHandler struct {
storageManager *storageManager
}
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// The keys for the values in context
const (
tokenKey contextKey = iota
handlerKey
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
@ -47,14 +57,12 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
// Store the CSRF handler in the context if a context key is specified
if cfg.HandlerContextKey != "" {
c.Locals(cfg.HandlerContextKey, &CSRFHandler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
}
// Store the CSRF handler in the context
c.Locals(handlerKey, &CSRFHandler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
var token string
@ -128,16 +136,34 @@ func New(config ...Config) fiber.Handler {
// Tell the browser that a new header value is generated
c.Vary(fiber.HeaderCookie)
// Store the token in the context if a context key is specified
if cfg.ContextKey != "" {
c.Locals(cfg.ContextKey, token)
}
// Store the token in the context
c.Locals(tokenKey, token)
// Continue stack
return c.Next()
}
}
// TokenFromContext returns the token found in the context
// returns an empty string if the token does not exist
func TokenFromContext(c fiber.Ctx) string {
token, ok := c.Locals(tokenKey).(string)
if !ok {
return ""
}
return token
}
// HandlerFromContext returns the CSRFHandler found in the context
// returns nil if the handler does not exist
func HandlerFromContext(c fiber.Ctx) *CSRFHandler {
handler, ok := c.Locals(handlerKey).(*CSRFHandler)
if !ok {
return nil
}
return handler
}
// getRawFromStorage returns the raw value from the storage for the given token
// returns nil if the token does not exist, is expired or is invalid
func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {

View File

@ -714,7 +714,8 @@ func Test_CSRF_DeleteToken(t *testing.T) {
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
if handler, ok := app.NewCtx(ctx).Locals(ConfigDefault.HandlerContextKey).(*CSRFHandler); ok {
handler := HandlerFromContext(app.NewCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.NewCtx(ctx)); err != nil {
t.Fatal(err)
}
@ -780,7 +781,8 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
if handler, ok := app.NewCtx(ctx).Locals(ConfigDefault.HandlerContextKey).(*CSRFHandler); ok {
handler := HandlerFromContext(app.NewCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.NewCtx(ctx)); err != nil {
t.Fatal(err)
}

View File

@ -4,10 +4,15 @@ import (
"github.com/gofiber/fiber/v3"
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
const (
localsKeyAllowed = "earlydata_allowed"
localsKeyAllowed contextKey = 0 // earlydata_allowed
)
// IsEarlyData returns true if the request is an early-data request
func IsEarly(c fiber.Ctx) bool {
return c.Locals(localsKeyAllowed) != nil
}

View File

@ -12,9 +12,13 @@ import (
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
const (
localsKeyIsFromCache = "idempotency_isfromcache"
localsKeyWasPutToCache = "idempotency_wasputtocache"
localsKeyIsFromCache contextKey = iota //
localsKeyWasPutToCache
)
func IsFromCache(c fiber.Ctx) bool {

View File

@ -31,7 +31,7 @@ func main() {
}))
app.Get("/", func(c fiber.Ctx) error {
token, _ := c.Locals("my_token").(string)
token := c.TokenFromContext(c) // "" is returned if not found
return c.SendString(token)
})

View File

@ -38,10 +38,6 @@ type Config struct {
// Validator is a function to validate key.
Validator func(fiber.Ctx, string) (bool, error)
// Context key to store the bearertoken from the token into context.
// Optional. Default: "token".
ContextKey string
}
// ConfigDefault is the default config
@ -57,7 +53,6 @@ var ConfigDefault = Config{
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
ContextKey: "token",
}
// Helper function to set default values
@ -87,9 +82,6 @@ func configDefault(config ...Config) Config {
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
if cfg.ContextKey == "" {
cfg.ContextKey = ConfigDefault.ContextKey
}
return cfg
}

View File

@ -9,6 +9,15 @@ import (
"github.com/gofiber/fiber/v3"
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// The keys for the values in context
const (
tokenKey contextKey = 0
)
// When there is no request of the key thrown ErrMissingOrMalformedAPIKey
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
@ -54,13 +63,23 @@ func New(config ...Config) fiber.Handler {
valid, err := cfg.Validator(c, key)
if err == nil && valid {
c.Locals(cfg.ContextKey, key)
c.Locals(tokenKey, key)
return cfg.SuccessHandler(c)
}
return cfg.ErrorHandler(c, err)
}
}
// TokenFromContext returns the bearer token from the request context.
// returns an empty string if the token does not exist
func TokenFromContext(c fiber.Ctx) string {
token, ok := c.Locals(tokenKey).(string)
if !ok {
return ""
}
return token
}
// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) {
return func(c fiber.Ctx) (string, error) {

View File

@ -498,10 +498,9 @@ func Test_Response_Header(t *testing.T) {
app := fiber.New()
app.Use(requestid.New(requestid.Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: func() string { return "Hello fiber!" },
ContextKey: "requestid",
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: func() string { return "Hello fiber!" },
}))
app.Use(New(Config{
Format: "${respHeader:X-Request-ID}",

View File

@ -21,12 +21,6 @@ type Config struct {
//
// Optional. Default: utils.UUID
Generator func() string
// ContextKey defines the key used when storing the request ID in
// the locals for a specific request.
//
// Optional. Default: requestid
ContextKey interface{}
}
// ConfigDefault is the default config
@ -34,10 +28,9 @@ type Config struct {
// requests made to the server. To conceal this value for better
// privacy, use the "utils.UUIDv4" generator.
var ConfigDefault = Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: utils.UUID,
ContextKey: "requestid",
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: utils.UUID,
}
// Helper function to set default values
@ -57,8 +50,5 @@ func configDefault(config ...Config) Config {
if cfg.Generator == nil {
cfg.Generator = ConfigDefault.Generator
}
if cfg.ContextKey == "" {
cfg.ContextKey = ConfigDefault.ContextKey
}
return cfg
}

View File

@ -4,6 +4,15 @@ import (
"github.com/gofiber/fiber/v3"
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// The keys for the values in context
const (
requestIDKey contextKey = iota
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
@ -25,9 +34,18 @@ func New(config ...Config) fiber.Handler {
c.Set(cfg.Header, rid)
// Add the request ID to locals
c.Locals(cfg.ContextKey, rid)
c.Locals(requestIDKey, rid)
// Continue stack
return c.Next()
}
}
// FromContext returns the request ID from context.
// If there is no request ID, an empty string is returned.
func FromContext(c fiber.Ctx) string {
if rid, ok := c.Locals(requestIDKey).(string); ok {
return rid
}
return ""
}

View File

@ -52,24 +52,21 @@ func Test_RequestID_Next(t *testing.T) {
}
// go test -run Test_RequestID_Locals
func Test_RequestID_Locals(t *testing.T) {
func Test_RequestID_FromContext(t *testing.T) {
t.Parallel()
reqID := "ThisIsARequestId"
type ContextKey int
const requestContextKey ContextKey = iota
app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
},
ContextKey: requestContextKey,
}))
var ctxVal string
app.Use(func(c fiber.Ctx) error {
ctxVal = c.Locals(requestContextKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
ctxVal = FromContext(c)
return c.Next()
})