mirror of https://github.com/gofiber/fiber.git
🔥 feat: Add support for context.Context in keyauth middleware (#3287)
* feat(middleware): add support to context.Context in keyauth middleware pretty straightforward option to use context.Context instead of just fiber.Ctx, tests added accordingly. * fix(middleware): include import that was missing from previous commit * fix(middleware): include missing import * Replace logger with panic * Update keyauth_test.go * Update keyauth_test.go --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>pull/3342/head
parent
208b9e36ba
commit
4177ab4086
|
@ -2,6 +2,7 @@
|
||||||
package keyauth
|
package keyauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -59,7 +60,10 @@ func New(config ...Config) fiber.Handler {
|
||||||
valid, err := cfg.Validator(c, key)
|
valid, err := cfg.Validator(c, key)
|
||||||
|
|
||||||
if err == nil && valid {
|
if err == nil && valid {
|
||||||
|
// Store in both Locals and Context
|
||||||
c.Locals(tokenKey, key)
|
c.Locals(tokenKey, key)
|
||||||
|
ctx := context.WithValue(c.Context(), tokenKey, key)
|
||||||
|
c.SetContext(ctx)
|
||||||
return cfg.SuccessHandler(c)
|
return cfg.SuccessHandler(c)
|
||||||
}
|
}
|
||||||
return cfg.ErrorHandler(c, err)
|
return cfg.ErrorHandler(c, err)
|
||||||
|
@ -68,12 +72,20 @@ func New(config ...Config) fiber.Handler {
|
||||||
|
|
||||||
// TokenFromContext returns the bearer token from the request context.
|
// TokenFromContext returns the bearer token from the request context.
|
||||||
// returns an empty string if the token does not exist
|
// returns an empty string if the token does not exist
|
||||||
func TokenFromContext(c fiber.Ctx) string {
|
func TokenFromContext(c any) string {
|
||||||
token, ok := c.Locals(tokenKey).(string)
|
switch ctx := c.(type) {
|
||||||
if !ok {
|
case context.Context:
|
||||||
return ""
|
if token, ok := ctx.Value(tokenKey).(string); ok {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
case fiber.Ctx:
|
||||||
|
if token, ok := ctx.Locals(tokenKey).(string); ok {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unsupported context type, expected fiber.Ctx or context.Context")
|
||||||
}
|
}
|
||||||
return token
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
|
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
|
||||||
|
|
|
@ -503,33 +503,67 @@ func Test_TokenFromContext_None(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_TokenFromContext(t *testing.T) {
|
func Test_TokenFromContext(t *testing.T) {
|
||||||
app := fiber.New()
|
// Test that TokenFromContext returns the correct token
|
||||||
// Wire up keyauth middleware to set TokenFromContext now
|
t.Run("fiber.Ctx", func(t *testing.T) {
|
||||||
app.Use(New(Config{
|
app := fiber.New()
|
||||||
KeyLookup: "header:Authorization",
|
app.Use(New(Config{
|
||||||
AuthScheme: "Basic",
|
KeyLookup: "header:Authorization",
|
||||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
AuthScheme: "Basic",
|
||||||
if key == CorrectKey {
|
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||||
return true, nil
|
if key == CorrectKey {
|
||||||
}
|
return true, nil
|
||||||
return false, ErrMissingOrMalformedAPIKey
|
}
|
||||||
},
|
return false, ErrMissingOrMalformedAPIKey
|
||||||
}))
|
},
|
||||||
// Define a test handler that checks TokenFromContext
|
}))
|
||||||
app.Get("/", func(c fiber.Ctx) error {
|
app.Get("/", func(c fiber.Ctx) error {
|
||||||
return c.SendString(TokenFromContext(c))
|
return c.SendString(TokenFromContext(c))
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
|
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||||
|
res, err := app.Test(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, CorrectKey, string(body))
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
t.Run("context.Context", func(t *testing.T) {
|
||||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
app := fiber.New()
|
||||||
// Send
|
app.Use(New(Config{
|
||||||
res, err := app.Test(req)
|
KeyLookup: "header:Authorization",
|
||||||
require.NoError(t, err)
|
AuthScheme: "Basic",
|
||||||
|
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||||
|
if key == CorrectKey {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, ErrMissingOrMalformedAPIKey
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
// Verify that TokenFromContext works with context.Context
|
||||||
|
app.Get("/", func(c fiber.Ctx) error {
|
||||||
|
ctx := c.Context()
|
||||||
|
token := TokenFromContext(ctx)
|
||||||
|
return c.SendString(token)
|
||||||
|
})
|
||||||
|
|
||||||
// Read the response body into a string
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||||
body, err := io.ReadAll(res.Body)
|
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||||
require.NoError(t, err)
|
res, err := app.Test(req)
|
||||||
require.Equal(t, CorrectKey, string(body))
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, CorrectKey, string(body))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid context type", func(t *testing.T) {
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_ = TokenFromContext("invalid")
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_AuthSchemeToken(t *testing.T) {
|
func Test_AuthSchemeToken(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue