mirror of https://github.com/gofiber/fiber.git
Revert "🔥 feat: Add support for context.Context in keyauth middleware (#3287)"
This reverts commit 4177ab4086
.
revert-3287-context-middleware
parent
87f3f0c8b6
commit
811732cd50
|
@ -2,7 +2,6 @@
|
|||
package keyauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
@ -60,10 +59,7 @@ func New(config ...Config) fiber.Handler {
|
|||
valid, err := cfg.Validator(c, key)
|
||||
|
||||
if err == nil && valid {
|
||||
// Store in both Locals and Context
|
||||
c.Locals(tokenKey, key)
|
||||
ctx := context.WithValue(c.Context(), tokenKey, key)
|
||||
c.SetContext(ctx)
|
||||
return cfg.SuccessHandler(c)
|
||||
}
|
||||
return cfg.ErrorHandler(c, err)
|
||||
|
@ -72,20 +68,12 @@ func New(config ...Config) fiber.Handler {
|
|||
|
||||
// TokenFromContext returns the bearer token from the request context.
|
||||
// returns an empty string if the token does not exist
|
||||
func TokenFromContext(c any) string {
|
||||
switch ctx := c.(type) {
|
||||
case context.Context:
|
||||
if token, ok := ctx.Value(tokenKey).(string); ok {
|
||||
return token
|
||||
}
|
||||
case fiber.Ctx:
|
||||
if token, ok := ctx.Locals(tokenKey).(string); ok {
|
||||
return token
|
||||
}
|
||||
default:
|
||||
panic("unsupported context type, expected fiber.Ctx or context.Context")
|
||||
func TokenFromContext(c fiber.Ctx) string {
|
||||
token, ok := c.Locals(tokenKey).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
return token
|
||||
}
|
||||
|
||||
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
|
||||
|
|
|
@ -503,67 +503,33 @@ func Test_TokenFromContext_None(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_TokenFromContext(t *testing.T) {
|
||||
// Test that TokenFromContext returns the correct token
|
||||
t.Run("fiber.Ctx", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
app := fiber.New()
|
||||
// Wire up keyauth middleware to set TokenFromContext now
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
// Define a test handler that checks TokenFromContext
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("context.Context", func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
// Verify that TokenFromContext works with context.Context
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
token := TokenFromContext(ctx)
|
||||
return c.SendString(token)
|
||||
})
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
// Send
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
})
|
||||
|
||||
t.Run("invalid context type", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_ = TokenFromContext("invalid")
|
||||
})
|
||||
})
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
}
|
||||
|
||||
func Test_AuthSchemeToken(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue