🔥 feat: Add support for context.Context in keyauth middleware (#3287)

* feat(middleware): add support to context.Context in keyauth middleware

pretty straightforward option to use context.Context instead of just
fiber.Ctx, tests added accordingly.

* fix(middleware): include import that was missing from previous commit

* fix(middleware): include missing import

* Replace logger with panic

* Update keyauth_test.go

* Update keyauth_test.go

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
pull/3342/head
vinicius 2025-03-07 04:23:24 -03:00 committed by GitHub
parent 208b9e36ba
commit 4177ab4086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 29 deletions

View File

@ -2,6 +2,7 @@
package keyauth package keyauth
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -59,7 +60,10 @@ func New(config ...Config) fiber.Handler {
valid, err := cfg.Validator(c, key) valid, err := cfg.Validator(c, key)
if err == nil && valid { if err == nil && valid {
// Store in both Locals and Context
c.Locals(tokenKey, key) c.Locals(tokenKey, key)
ctx := context.WithValue(c.Context(), tokenKey, key)
c.SetContext(ctx)
return cfg.SuccessHandler(c) return cfg.SuccessHandler(c)
} }
return cfg.ErrorHandler(c, err) return cfg.ErrorHandler(c, err)
@ -68,12 +72,20 @@ func New(config ...Config) fiber.Handler {
// TokenFromContext returns the bearer token from the request context. // TokenFromContext returns the bearer token from the request context.
// returns an empty string if the token does not exist // returns an empty string if the token does not exist
func TokenFromContext(c fiber.Ctx) string { func TokenFromContext(c any) string {
token, ok := c.Locals(tokenKey).(string) switch ctx := c.(type) {
if !ok { case context.Context:
return "" if token, ok := ctx.Value(tokenKey).(string); ok {
return token
}
case fiber.Ctx:
if token, ok := ctx.Locals(tokenKey).(string); ok {
return token
}
default:
panic("unsupported context type, expected fiber.Ctx or context.Context")
} }
return token return ""
} }
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found // MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found

View File

@ -503,33 +503,67 @@ func Test_TokenFromContext_None(t *testing.T) {
} }
func Test_TokenFromContext(t *testing.T) { func Test_TokenFromContext(t *testing.T) {
app := fiber.New() // Test that TokenFromContext returns the correct token
// Wire up keyauth middleware to set TokenFromContext now t.Run("fiber.Ctx", func(t *testing.T) {
app.Use(New(Config{ app := fiber.New()
KeyLookup: "header:Authorization", app.Use(New(Config{
AuthScheme: "Basic", KeyLookup: "header:Authorization",
Validator: func(_ fiber.Ctx, key string) (bool, error) { AuthScheme: "Basic",
if key == CorrectKey { Validator: func(_ fiber.Ctx, key string) (bool, error) {
return true, nil if key == CorrectKey {
} return true, nil
return false, ErrMissingOrMalformedAPIKey }
}, return false, ErrMissingOrMalformedAPIKey
})) },
// Define a test handler that checks TokenFromContext }))
app.Get("/", func(c fiber.Ctx) error { app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c)) return c.SendString(TokenFromContext(c))
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
res, err := app.Test(req)
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
}) })
req := httptest.NewRequest(fiber.MethodGet, "/", nil) t.Run("context.Context", func(t *testing.T) {
req.Header.Add("Authorization", "Basic "+CorrectKey) app := fiber.New()
// Send app.Use(New(Config{
res, err := app.Test(req) KeyLookup: "header:Authorization",
require.NoError(t, err) AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Verify that TokenFromContext works with context.Context
app.Get("/", func(c fiber.Ctx) error {
ctx := c.Context()
token := TokenFromContext(ctx)
return c.SendString(token)
})
// Read the response body into a string req := httptest.NewRequest(fiber.MethodGet, "/", nil)
body, err := io.ReadAll(res.Body) req.Header.Add("Authorization", "Basic "+CorrectKey)
require.NoError(t, err) res, err := app.Test(req)
require.Equal(t, CorrectKey, string(body)) require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
})
t.Run("invalid context type", func(t *testing.T) {
require.Panics(t, func() {
_ = TokenFromContext("invalid")
})
})
} }
func Test_AuthSchemeToken(t *testing.T) { func Test_AuthSchemeToken(t *testing.T) {