mirror of https://github.com/gofiber/fiber.git
🔥 Feature: Add support for custom KeyLookup functions in the Keyauth middleware (#3028)
* port over FallbackKeyLookups from v2 middleware to v3 Signed-off-by: Dave Lee <dave@gray101.com> * bot pointed out that I missed the format variable Signed-off-by: Dave Lee <dave@gray101.com> * fix lint and gofumpt issues Signed-off-by: Dave Lee <dave@gray101.com> * major revision: instead of FallbackKeyLookups, expose CustomKeyLookup as function, with utility functions to make creating these easy Signed-off-by: Dave Lee <dave@gray101.com> * add more tests to boost coverage Signed-off-by: Dave Lee <dave@gray101.com> * teardown code and cleanup Signed-off-by: Dave Lee <dave@gray101.com> * test fixes Signed-off-by: Dave Lee <dave@gray101.com> * slight boost to test coverage Signed-off-by: Dave Lee <dave@gray101.com> * docs: fix md table alignment * fix comments - change some names, expose functions, improve docs Signed-off-by: Dave Lee <dave@gray101.com> * missed one old name Signed-off-by: Dave Lee <dave@gray101.com> * fix some suggestions from the bot - error messages, test coverage, mark purely defensive code Signed-off-by: Dave Lee <dave@gray101.com> --------- Signed-off-by: Dave Lee <dave@gray101.com> Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: Jason McNeil <sixcolors@mac.com> Co-authored-by: RW <rene@gofiber.io>export-buildtree
parent
c9b7b1aefb
commit
2db1858513
|
@ -214,14 +214,15 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000
|
|||
|
||||
## Config
|
||||
|
||||
| Property | Type | Description | Default |
|
||||
|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
|
||||
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
|
||||
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
|
||||
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
|
||||
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
|
||||
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
|
||||
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
|
||||
| Property | Type | Description | Default |
|
||||
|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
|
||||
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
|
||||
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
|
||||
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
|
||||
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract the key from the request. | "header:Authorization" |
|
||||
| CustomKeyLookup | `KeyLookupFunc` aka `func(c fiber.Ctx) (string, error)` | If more complex logic is required to extract the key from the request, an arbitrary function to extract it can be specified here. Utility helper functions are described below. | `nil` |
|
||||
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
|
||||
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
|
||||
|
||||
## Default Config
|
||||
|
||||
|
@ -237,6 +238,13 @@ var ConfigDefault = Config{
|
|||
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
|
||||
},
|
||||
KeyLookup: "header:" + fiber.HeaderAuthorization,
|
||||
CustomKeyLookup: nil,
|
||||
AuthScheme: "Bearer",
|
||||
}
|
||||
```
|
||||
|
||||
## CustomKeyLookup
|
||||
|
||||
Two public utility functions are provided that may be useful when creating custom extraction:
|
||||
* `DefaultKeyLookup(keyLookup string, authScheme string)`: This is the function that implements the default `KeyLookup` behavior, exposed to be used as a component of custom parsing logic
|
||||
* `MultipleKeySourceLookup(keyLookups []string, authScheme string)`: Creates a CustomKeyLookup function that checks each listed source using the above function until a key is found or the options are all exhausted. For example, `MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "cookie:apikey"}, "Bearer")` would first check the standard Authorization header, checks the `x-api-key` header next, and finally checks for a cookie named `apikey`. If any of these contain a valid API key, the request continues. Otherwise, an error is returned.
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
type KeyLookupFunc func(c fiber.Ctx) (string, error)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
|
@ -32,6 +34,8 @@ type Config struct {
|
|||
// - "cookie:<name>"
|
||||
KeyLookup string
|
||||
|
||||
CustomKeyLookup KeyLookupFunc
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
AuthScheme string
|
||||
|
@ -51,8 +55,9 @@ var ConfigDefault = Config{
|
|||
}
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
|
||||
},
|
||||
KeyLookup: "header:" + fiber.HeaderAuthorization,
|
||||
AuthScheme: "Bearer",
|
||||
KeyLookup: "header:" + fiber.HeaderAuthorization,
|
||||
CustomKeyLookup: nil,
|
||||
AuthScheme: "Bearer",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
|
|
|
@ -3,6 +3,7 @@ package keyauth
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
|
@ -34,17 +35,12 @@ func New(config ...Config) fiber.Handler {
|
|||
cfg := configDefault(config...)
|
||||
|
||||
// Initialize
|
||||
parts := strings.Split(cfg.KeyLookup, ":")
|
||||
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
|
||||
switch parts[0] {
|
||||
case query:
|
||||
extractor = keyFromQuery(parts[1])
|
||||
case form:
|
||||
extractor = keyFromForm(parts[1])
|
||||
case param:
|
||||
extractor = keyFromParam(parts[1])
|
||||
case cookie:
|
||||
extractor = keyFromCookie(parts[1])
|
||||
if cfg.CustomKeyLookup == nil {
|
||||
var err error
|
||||
cfg.CustomKeyLookup, err = DefaultKeyLookup(cfg.KeyLookup, cfg.AuthScheme)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("unable to create lookup function: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Return middleware handler
|
||||
|
@ -55,7 +51,7 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
|
||||
// Extract and verify key
|
||||
key, err := extractor(c)
|
||||
key, err := cfg.CustomKeyLookup(c)
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
|
@ -80,8 +76,53 @@ func TokenFromContext(c fiber.Ctx) string {
|
|||
return token
|
||||
}
|
||||
|
||||
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
|
||||
// Each element should be specified according to the format used in KeyLookup
|
||||
func MultipleKeySourceLookup(keyLookups []string, authScheme string) (KeyLookupFunc, error) {
|
||||
subExtractors := map[string]KeyLookupFunc{}
|
||||
var err error
|
||||
for _, keyLookup := range keyLookups {
|
||||
subExtractors[keyLookup], err = DefaultKeyLookup(keyLookup, authScheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
for keyLookup, subExtractor := range subExtractors {
|
||||
res, err := subExtractor(c)
|
||||
if err == nil && res != "" {
|
||||
return res, nil
|
||||
}
|
||||
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
|
||||
// Defensive Code - not currently possible to hit
|
||||
return "", fmt.Errorf("[%s] %w", keyLookup, err)
|
||||
}
|
||||
}
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DefaultKeyLookup(keyLookup, authScheme string) (KeyLookupFunc, error) {
|
||||
parts := strings.Split(keyLookup, ":")
|
||||
if len(parts) <= 1 {
|
||||
return nil, fmt.Errorf("invalid keyLookup: %q, expected format 'source:name'", keyLookup)
|
||||
}
|
||||
extractor := KeyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header:
|
||||
switch parts[0] {
|
||||
case query:
|
||||
extractor = KeyFromQuery(parts[1])
|
||||
case form:
|
||||
extractor = KeyFromForm(parts[1])
|
||||
case param:
|
||||
extractor = KeyFromParam(parts[1])
|
||||
case cookie:
|
||||
extractor = KeyFromCookie(parts[1])
|
||||
}
|
||||
return extractor, nil
|
||||
}
|
||||
|
||||
// keyFromHeader returns a function that extracts api key from the request header.
|
||||
func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) {
|
||||
func KeyFromHeader(header, authScheme string) KeyLookupFunc {
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
auth := c.Get(header)
|
||||
l := len(authScheme)
|
||||
|
@ -96,7 +137,7 @@ func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error)
|
|||
}
|
||||
|
||||
// keyFromQuery returns a function that extracts api key from the query string.
|
||||
func keyFromQuery(param string) func(c fiber.Ctx) (string, error) {
|
||||
func KeyFromQuery(param string) KeyLookupFunc {
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
key := fiber.Query[string](c, param)
|
||||
if key == "" {
|
||||
|
@ -107,7 +148,7 @@ func keyFromQuery(param string) func(c fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// keyFromForm returns a function that extracts api key from the form.
|
||||
func keyFromForm(param string) func(c fiber.Ctx) (string, error) {
|
||||
func KeyFromForm(param string) KeyLookupFunc {
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
key := c.FormValue(param)
|
||||
if key == "" {
|
||||
|
@ -118,7 +159,7 @@ func keyFromForm(param string) func(c fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// keyFromParam returns a function that extracts api key from the url param string.
|
||||
func keyFromParam(param string) func(c fiber.Ctx) (string, error) {
|
||||
func KeyFromParam(param string) KeyLookupFunc {
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
key, err := url.PathUnescape(c.Params(param))
|
||||
if err != nil {
|
||||
|
@ -129,7 +170,7 @@ func keyFromParam(param string) func(c fiber.Ctx) (string, error) {
|
|||
}
|
||||
|
||||
// keyFromCookie returns a function that extracts api key from the named cookie.
|
||||
func keyFromCookie(name string) func(c fiber.Ctx) (string, error) {
|
||||
func KeyFromCookie(name string) KeyLookupFunc {
|
||||
return func(c fiber.Ctx) (string, error) {
|
||||
key := c.Cookies(name)
|
||||
if key == "" {
|
||||
|
|
|
@ -130,6 +130,109 @@ func Test_AuthSources(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPanicOnInvalidConfiguration(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
authMiddleware := New(Config{
|
||||
KeyLookup: "invalid",
|
||||
})
|
||||
// We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable.
|
||||
app := fiber.New()
|
||||
defer func() { // testing panics, defer block to ensure cleanup
|
||||
err := app.Shutdown()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
app.Use(authMiddleware)
|
||||
}, "should panic if Validator is missing")
|
||||
|
||||
require.Panics(t, func() {
|
||||
authMiddleware := New(Config{
|
||||
KeyLookup: "invalid",
|
||||
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
})
|
||||
// We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable.
|
||||
app := fiber.New()
|
||||
defer func() { // testing panics, defer block to ensure cleanup
|
||||
err := app.Shutdown()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
app.Use(authMiddleware)
|
||||
}, "should panic if CustomKeyLookup is not set AND KeyLookup has an invalid value")
|
||||
}
|
||||
|
||||
func TestCustomKeyUtilityFunctionErrors(t *testing.T) {
|
||||
const (
|
||||
scheme = "Bearer"
|
||||
)
|
||||
|
||||
// Invalid element while parsing
|
||||
_, err := DefaultKeyLookup("invalid", scheme)
|
||||
require.Error(t, err, "DefaultKeyLookup should fail for 'invalid' keyLookup")
|
||||
|
||||
_, err = MultipleKeySourceLookup([]string{"header:key", "invalid"}, scheme)
|
||||
require.Error(t, err, "MultipleKeySourceLookup should fail for 'invalid' keyLookup")
|
||||
}
|
||||
|
||||
func TestMultipleKeyLookup(t *testing.T) {
|
||||
const (
|
||||
desc = "auth with correct key"
|
||||
success = "Success!"
|
||||
scheme = "Bearer"
|
||||
)
|
||||
|
||||
// setup the fiber endpoint
|
||||
app := fiber.New()
|
||||
|
||||
customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme)
|
||||
require.NoError(t, err)
|
||||
|
||||
authMiddleware := New(Config{
|
||||
CustomKeyLookup: customKeyLookup,
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
})
|
||||
app.Use(authMiddleware)
|
||||
app.Get("/foo", func(c fiber.Ctx) error {
|
||||
return c.SendString(success)
|
||||
})
|
||||
|
||||
// construct the test HTTP request
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
|
||||
require.NoError(t, err)
|
||||
q := req.URL.Query()
|
||||
q.Add("key", CorrectKey)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
res, err := app.Test(req, -1)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// test the body of the request
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.Equal(t, 200, res.StatusCode, desc)
|
||||
// body
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, success, string(body), desc)
|
||||
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// construct a second request without proper key
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
|
||||
require.NoError(t, err)
|
||||
res, err = app.Test(req, -1)
|
||||
require.NoError(t, err)
|
||||
errBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
|
||||
}
|
||||
|
||||
func Test_MultipleKeyAuth(t *testing.T) {
|
||||
// setup the fiber endpoint
|
||||
app := fiber.New()
|
||||
|
@ -376,6 +479,55 @@ func Test_CustomNextFunc(t *testing.T) {
|
|||
require.Equal(t, string(body), ErrMissingOrMalformedAPIKey.Error())
|
||||
}
|
||||
|
||||
func Test_TokenFromContext_None(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// Define a test handler that checks TokenFromContext
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
// Verify a "" is sent back if nothing sets the token on the context.
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
// Send
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, body)
|
||||
}
|
||||
|
||||
func Test_TokenFromContext(t *testing.T) {
|
||||
app := fiber.New()
|
||||
// Wire up keyauth middleware to set TokenFromContext now
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(_ fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
// Define a test handler that checks TokenFromContext
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
return c.SendString(TokenFromContext(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", "Basic "+CorrectKey)
|
||||
// Send
|
||||
res, err := app.Test(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, CorrectKey, string(body))
|
||||
}
|
||||
|
||||
func Test_AuthSchemeToken(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
|
|
Loading…
Reference in New Issue