diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index e245ba42..54ecdbe5 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -2,6 +2,7 @@ package keyauth import ( + "context" "errors" "fmt" "net/url" @@ -59,7 +60,10 @@ func New(config ...Config) fiber.Handler { valid, err := cfg.Validator(c, key) if err == nil && valid { + // Store in both Locals and Context c.Locals(tokenKey, key) + ctx := context.WithValue(c.Context(), tokenKey, key) + c.SetContext(ctx) return cfg.SuccessHandler(c) } return cfg.ErrorHandler(c, err) @@ -68,12 +72,20 @@ func New(config ...Config) fiber.Handler { // TokenFromContext returns the bearer token from the request context. // returns an empty string if the token does not exist -func TokenFromContext(c fiber.Ctx) string { - token, ok := c.Locals(tokenKey).(string) - if !ok { - return "" +func TokenFromContext(c any) string { + switch ctx := c.(type) { + case context.Context: + if token, ok := ctx.Value(tokenKey).(string); ok { + return token + } + case fiber.Ctx: + if token, ok := ctx.Locals(tokenKey).(string); ok { + return token + } + default: + panic("unsupported context type, expected fiber.Ctx or context.Context") } - return token + return "" } // MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 72c9d3c1..27c4e5a0 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -503,33 +503,67 @@ func Test_TokenFromContext_None(t *testing.T) { } func Test_TokenFromContext(t *testing.T) { - app := fiber.New() - // Wire up keyauth middleware to set TokenFromContext now - app.Use(New(Config{ - KeyLookup: "header:Authorization", - AuthScheme: "Basic", - Validator: func(_ fiber.Ctx, key string) (bool, error) { - if key == CorrectKey { - return true, nil - } - return false, ErrMissingOrMalformedAPIKey - }, - })) - // Define a test handler that checks TokenFromContext - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(TokenFromContext(c)) + // Test that TokenFromContext returns the correct token + t.Run("fiber.Ctx", func(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Basic", + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(TokenFromContext(c)) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Add("Authorization", "Basic "+CorrectKey) + res, err := app.Test(req) + require.NoError(t, err) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, CorrectKey, string(body)) }) - req := httptest.NewRequest(fiber.MethodGet, "/", nil) - req.Header.Add("Authorization", "Basic "+CorrectKey) - // Send - res, err := app.Test(req) - require.NoError(t, err) + t.Run("context.Context", func(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Basic", + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + // Verify that TokenFromContext works with context.Context + app.Get("/", func(c fiber.Ctx) error { + ctx := c.Context() + token := TokenFromContext(ctx) + return c.SendString(token) + }) - // Read the response body into a string - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, CorrectKey, string(body)) + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Add("Authorization", "Basic "+CorrectKey) + res, err := app.Test(req) + require.NoError(t, err) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, CorrectKey, string(body)) + }) + + t.Run("invalid context type", func(t *testing.T) { + require.Panics(t, func() { + _ = TokenFromContext("invalid") + }) + }) } func Test_AuthSchemeToken(t *testing.T) {