🔥 feat: Add support for context.Context in keyauth middleware (#3287)

* feat(middleware): add support to context.Context in keyauth middleware

pretty straightforward option to use context.Context instead of just
fiber.Ctx, tests added accordingly.

* fix(middleware): include import that was missing from previous commit

* fix(middleware): include missing import

* Replace logger with panic

* Update keyauth_test.go

* Update keyauth_test.go

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
pull/3342/head
vinicius 2025-03-07 04:23:24 -03:00 committed by GitHub
parent 208b9e36ba
commit 4177ab4086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 29 deletions

View File

@ -2,6 +2,7 @@
package keyauth
import (
"context"
"errors"
"fmt"
"net/url"
@ -59,7 +60,10 @@ func New(config ...Config) fiber.Handler {
valid, err := cfg.Validator(c, key)
if err == nil && valid {
// Store in both Locals and Context
c.Locals(tokenKey, key)
ctx := context.WithValue(c.Context(), tokenKey, key)
c.SetContext(ctx)
return cfg.SuccessHandler(c)
}
return cfg.ErrorHandler(c, err)
@ -68,12 +72,20 @@ func New(config ...Config) fiber.Handler {
// TokenFromContext returns the bearer token from the request context.
// returns an empty string if the token does not exist
func TokenFromContext(c fiber.Ctx) string {
token, ok := c.Locals(tokenKey).(string)
if !ok {
return ""
func TokenFromContext(c any) string {
switch ctx := c.(type) {
case context.Context:
if token, ok := ctx.Value(tokenKey).(string); ok {
return token
}
case fiber.Ctx:
if token, ok := ctx.Locals(tokenKey).(string); ok {
return token
}
default:
panic("unsupported context type, expected fiber.Ctx or context.Context")
}
return token
return ""
}
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found

View File

@ -503,33 +503,67 @@ func Test_TokenFromContext_None(t *testing.T) {
}
func Test_TokenFromContext(t *testing.T) {
app := fiber.New()
// Wire up keyauth middleware to set TokenFromContext now
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler that checks TokenFromContext
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
// Test that TokenFromContext returns the correct token
t.Run("fiber.Ctx", func(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
res, err := app.Test(req)
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
// Send
res, err := app.Test(req)
require.NoError(t, err)
t.Run("context.Context", func(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Verify that TokenFromContext works with context.Context
app.Get("/", func(c fiber.Ctx) error {
ctx := c.Context()
token := TokenFromContext(ctx)
return c.SendString(token)
})
// Read the response body into a string
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
res, err := app.Test(req)
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
})
t.Run("invalid context type", func(t *testing.T) {
require.Panics(t, func() {
_ = TokenFromContext("invalid")
})
})
}
func Test_AuthSchemeToken(t *testing.T) {