fiber/middleware/keyauth/keyauth_test.go

651 lines
18 KiB
Go

package keyauth
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// CorrectKey is the only API key the validators in this file accept.
// It deliberately mixes quotes, backslashes, a backtick and URL-reserved
// characters so that every key source has to carry it through unchanged.
const CorrectKey = "specials: !$%,.#\"!?~`<>@$^*(){}[]|/\\123"
// testConfig is passed to app.Test by several tests in this file.
// NOTE(review): a zero Timeout presumably disables app.Test's response
// timeout — confirm against the fiber.TestConfig documentation.
var testConfig = fiber.TestConfig{
Timeout: 0,
}
// Test_AuthSources runs the same three scenarios (correct key, empty key,
// wrong key) against every built-in key source: header, cookie, query,
// param and form. Each scenario runs as its own subtest on a fresh app.
func Test_AuthSources(t *testing.T) {
	testSources := []string{"header", "cookie", "query", "param", "form"}

	tests := []struct {
		route         string
		authTokenName string
		description   string
		APIKey        string
		expectedBody  string
		expectedCode  int
	}{
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with correct key",
			APIKey:        CorrectKey,
			expectedCode:  200,
			expectedBody:  "Success!",
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with no key",
			APIKey:        "",
			expectedCode:  401, // 404 in case of param authentication
			expectedBody:  "missing or malformed API Key",
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with wrong key",
			APIKey:        "WRONGKEY",
			expectedCode:  401,
			expectedBody:  "missing or malformed API Key",
		},
	}

	for _, authSource := range testSources {
		t.Run(authSource, func(t *testing.T) {
			for _, test := range tests {
				t.Run(test.description, func(t *testing.T) {
					// UnescapePath must be enabled: with the default (false),
					// escaped characters in CorrectKey (such as `\"`) are not
					// handled correctly by these tests.
					app := fiber.New(fiber.Config{UnescapePath: true})
					authMiddleware := New(Config{
						KeyLookup: authSource + ":" + test.authTokenName,
						Validator: func(_ fiber.Ctx, key string) (bool, error) {
							if key == CorrectKey {
								return true, nil
							}
							return false, ErrMissingOrMalformedAPIKey
						},
					})

					// The param source reads the key from the path, so the route
					// itself must declare the parameter (e.g. "/:access_token").
					route := test.route
					if authSource == "param" {
						route = test.route + ":" + test.authTokenName
						app.Use(route, authMiddleware)
					} else {
						app.Use(authMiddleware)
					}
					app.Get(route, func(c fiber.Ctx) error {
						return c.SendString("Success!")
					})

					req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
					require.NoError(t, err)

					// Place the key where the current source expects it.
					switch authSource {
					case "header":
						req.Header.Set(test.authTokenName, test.APIKey)
					case "cookie":
						req.Header.Set("Cookie", test.authTokenName+"="+test.APIKey)
					case "query", "form":
						// NOTE(review): the form source is exercised through the
						// query string, which relies on the form lookup also
						// consulting query args; an urlencoded body would test it
						// more directly.
						q := req.URL.Query()
						q.Add(test.authTokenName, test.APIKey)
						req.URL.RawQuery = q.Encode()
					case "param":
						req.URL.Path += url.PathEscape(test.APIKey)
					}

					// With param authentication and an empty key the request path is
					// just "/", which does not match "/:access_token": the router
					// answers 404 (not found) instead of the middleware's 401.
					expectedCode, expectedBody := test.expectedCode, test.expectedBody
					if authSource == "param" && test.APIKey == "" {
						expectedCode = http.StatusNotFound
						expectedBody = "Cannot GET /"
					}

					res, err := app.Test(req, testConfig)
					require.NoError(t, err, test.description)

					body, err := io.ReadAll(res.Body)
					require.NoError(t, err, test.description)
					require.NoError(t, res.Body.Close())

					require.Equal(t, expectedCode, res.StatusCode, test.description)
					require.Equal(t, expectedBody, string(body), test.description)
				})
			}
		})
	}
}
// TestPanicOnInvalidConfiguration verifies that New rejects unusable
// configurations by panicking at construction time.
func TestPanicOnInvalidConfiguration(t *testing.T) {
	// A valid KeyLookup is used so that the missing Validator is the only
	// problem; with an invalid KeyLookup as well, the assertion could be
	// satisfied by the KeyLookup check instead.
	require.Panics(t, func() {
		_ = New(Config{
			KeyLookup: "header:key",
		})
	}, "should panic if Validator is missing")

	require.Panics(t, func() {
		_ = New(Config{
			KeyLookup: "invalid",
			Validator: func(_ fiber.Ctx, _ string) (bool, error) {
				return true, nil
			},
		})
	}, "should panic if CustomKeyLookup is not set AND KeyLookup has an invalid value")
}
// TestCustomKeyUtilityFunctionErrors ensures both key-lookup builders return
// an error when handed a source spec that does not follow the
// "<source>:<name>" form used elsewhere in this file.
func TestCustomKeyUtilityFunctionErrors(t *testing.T) {
	const scheme = "Bearer"

	// A single malformed spec.
	_, singleErr := DefaultKeyLookup("invalid", scheme)
	require.Error(t, singleErr, "DefaultKeyLookup should fail for 'invalid' keyLookup")

	// One malformed entry must fail the whole list, even after a valid one.
	sources := []string{"header:key", "invalid"}
	_, multiErr := MultipleKeySourceLookup(sources, scheme)
	require.Error(t, multiErr, "MultipleKeySourceLookup should fail for 'invalid' keyLookup")
}
// TestMultipleKeyLookup checks that a lookup built by MultipleKeySourceLookup
// finds a key supplied through one of several sources (here: the query
// string), and that a request carrying no key at all is rejected.
func TestMultipleKeyLookup(t *testing.T) {
	const (
		desc    = "auth with correct key"
		success = "Success!"
		scheme  = "Bearer"
	)

	app := fiber.New()

	customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme)
	require.NoError(t, err)

	app.Use(New(Config{
		CustomKeyLookup: customKeyLookup,
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/foo", func(c fiber.Ctx) error {
		return c.SendString(success)
	})

	// Request 1: the key is present only in the query string.
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
	require.NoError(t, err)
	q := req.URL.Query()
	q.Add("key", CorrectKey)
	req.URL.RawQuery = q.Encode()

	res, err := app.Test(req, testConfig)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusOK, res.StatusCode, desc)
	require.Equal(t, success, string(body), desc)

	// Request 2: no key in any of the configured sources.
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
	require.NoError(t, err)

	res, err = app.Test(req, testConfig)
	require.NoError(t, err)
	errBody, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
}
// Test_MultipleKeyAuth mounts two keyauth middlewares, each limited to one
// route through Next and each with its own password. It checks that "/"
// stays public while "/auth1" and "/auth2" accept only their own key.
func Test_MultipleKeyAuth(t *testing.T) {
	app := fiber.New()

	// keyauth for /auth1: Next skips this middleware for every other URL.
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.OriginalURL() != "/auth1"
		},
		KeyLookup: "header:key",
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password1" {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// keyauth for /auth2: Next skips this middleware for every other URL.
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.OriginalURL() != "/auth2"
		},
		KeyLookup: "header:key",
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password2" {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("No auth needed!")
	})
	app.Get("/auth1", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth1!")
	})
	app.Get("/auth2", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth2!")
	})

	tests := []struct {
		route        string
		description  string
		APIKey       string
		expectedBody string
		expectedCode int
	}{
		// No auth needed for /
		{
			route:        "/",
			description:  "No password needed",
			APIKey:       "",
			expectedCode: 200,
			expectedBody: "No auth needed!",
		},
		// auth needed for auth1
		{
			route:        "/auth1",
			description:  "Normal Authentication Case",
			APIKey:       "password1",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth1!",
		},
		{
			route:        "/auth1",
			description:  "Wrong API Key",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: "missing or malformed API Key",
		},
		{
			route:        "/auth1",
			description:  "No API Key",
			APIKey:       "",
			expectedCode: 401,
			expectedBody: "missing or malformed API Key",
		},
		// Auth 2 has a different password
		{
			route:        "/auth2",
			description:  "Normal Authentication Case for auth2",
			APIKey:       "password2",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth2!",
		},
		{
			route:        "/auth2",
			description:  "Wrong API Key",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: "missing or malformed API Key",
		},
		{
			route:        "/auth2",
			description:  "No API Key",
			APIKey:       "",
			expectedCode: 401,
			expectedBody: "missing or malformed API Key",
		},
	}

	for _, test := range tests {
		// Descriptions repeat across routes, so include the route in messages.
		msg := test.route + ": " + test.description

		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
		require.NoError(t, err, msg)
		if test.APIKey != "" {
			req.Header.Set("key", test.APIKey)
		}

		res, err := app.Test(req, testConfig)
		require.NoError(t, err, msg)

		body, err := io.ReadAll(res.Body)
		require.NoError(t, err, msg)
		require.NoError(t, res.Body.Close(), msg)

		require.Equal(t, test.expectedCode, res.StatusCode, msg)
		require.Equal(t, test.expectedBody, string(body), msg)
	}
}
// Test_CustomSuccessAndFailureHandlers checks that a configured ErrorHandler
// answers a request without a key, and a configured SuccessHandler answers a
// request with a valid key. In both cases the route handler must not run.
func Test_CustomSuccessAndFailureHandlers(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		SuccessHandler: func(c fiber.Ctx) error {
			return c.Status(fiber.StatusOK).SendString("API key is valid and request was handled by custom success handler")
		},
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusUnauthorized).SendString("API key is invalid and request was handled by custom error handler")
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// Both outcomes are answered by the custom handlers above, so reaching
	// this route handler is a failure.
	app.Get("/", func(_ fiber.Ctx) error {
		t.Error("Test handler should not be called")
		return nil
	})

	// No API key: the custom ErrorHandler responds.
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, "API key is invalid and request was handled by custom error handler", string(body))

	// Valid key in the Authorization header with the Bearer scheme: the
	// custom SuccessHandler responds.
	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req.Header.Add("Authorization", "Bearer "+CorrectKey)
	res, err = app.Test(req)
	require.NoError(t, err)
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid and request was handled by custom success handler", string(body))
}
// Test_CustomNextFunc checks that Next can exempt a path from authentication.
// "/allowed" is served without a key, while any other path is still rejected
// when the key is missing or is sent with an auth scheme the config does not
// use.
func Test_CustomNextFunc(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		// Skip keyauth entirely for "/allowed".
		Next: func(c fiber.Ctx) bool {
			return c.Path() == "/allowed"
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	app.Get("/allowed", func(c fiber.Ctx) error {
		return c.SendString("API key is valid and request was allowed by custom filter")
	})

	// "/allowed" is skipped by Next, so no key is required.
	req := httptest.NewRequest(fiber.MethodGet, "/allowed", nil)
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid and request was allowed by custom filter", string(body))

	// Any other path goes through keyauth; without a key it is rejected.
	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
	res, err = app.Test(req)
	require.NoError(t, err)
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))

	// The correct key is still rejected when sent with the "Basic" scheme,
	// because this config does not set AuthScheme to "Basic".
	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
	req.Header.Add("Authorization", "Basic "+CorrectKey)
	res, err = app.Test(req)
	require.NoError(t, err)
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
// Test_TokenFromContext_None checks that TokenFromContext returns "" when no
// keyauth middleware has stored a token for the request.
func Test_TokenFromContext_None(t *testing.T) {
	app := fiber.New()

	// No keyauth middleware is registered, so nothing stores a token.
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(TokenFromContext(c))
	})

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())

	// Assert the status as well: an empty body alone does not prove that the
	// handler above actually ran.
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Empty(t, body)
}
// Test_TokenFromContext checks that TokenFromContext returns the key stored by
// the middleware, whether it is given the fiber.Ctx itself or the value of
// c.Context(). It also checks that TokenFromContext panics for an unsupported
// argument type.
func Test_TokenFromContext(t *testing.T) {
	// newApp guards "/" with keyauth (scheme "Basic") and serves it with handler.
	newApp := func(handler fiber.Handler) *fiber.App {
		app := fiber.New()
		app.Use(New(Config{
			KeyLookup:  "header:Authorization",
			AuthScheme: "Basic",
			Validator: func(_ fiber.Ctx, key string) (bool, error) {
				if key == CorrectKey {
					return true, nil
				}
				return false, ErrMissingOrMalformedAPIKey
			},
		}))
		app.Get("/", handler)
		return app
	}

	// fetchBody sends "Authorization: Basic <CorrectKey>" to "/", asserts a
	// 200 response, and returns the body.
	fetchBody := func(t *testing.T, app *fiber.App) string {
		t.Helper()
		req := httptest.NewRequest(fiber.MethodGet, "/", nil)
		req.Header.Add("Authorization", "Basic "+CorrectKey)
		res, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.NoError(t, res.Body.Close())
		require.Equal(t, http.StatusOK, res.StatusCode)
		return string(body)
	}

	t.Run("fiber.Ctx", func(t *testing.T) {
		app := newApp(func(c fiber.Ctx) error {
			return c.SendString(TokenFromContext(c))
		})
		require.Equal(t, CorrectKey, fetchBody(t, app))
	})

	t.Run("context.Context", func(t *testing.T) {
		// TokenFromContext must also accept the value returned by c.Context().
		app := newApp(func(c fiber.Ctx) error {
			ctx := c.Context()
			return c.SendString(TokenFromContext(ctx))
		})
		require.Equal(t, CorrectKey, fetchBody(t, app))
	})

	t.Run("invalid context type", func(t *testing.T) {
		require.Panics(t, func() {
			_ = TokenFromContext("invalid")
		})
	})
}
// Test_AuthSchemeToken checks that with AuthScheme "Token" the middleware
// rejects a request without a key and accepts "Authorization: Token <key>".
func Test_AuthSchemeToken(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		AuthScheme: "Token",
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// No key: the default error handling rejects the request.
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))

	// Valid key using the "Token" scheme in the Authorization header.
	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req.Header.Add("Authorization", "Token "+CorrectKey)
	res, err = app.Test(req)
	require.NoError(t, err)
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid", string(body))
}
// Test_AuthSchemeBasic checks that with AuthScheme "Basic" the middleware
// rejects a request without a key and accepts "Authorization: Basic <key>".
func Test_AuthSchemeBasic(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		KeyLookup:  "header:Authorization",
		AuthScheme: "Basic",
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// No key: the default error handling rejects the request.
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))

	// Valid key using the "Basic" scheme in the Authorization header.
	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req.Header.Add("Authorization", "Basic "+CorrectKey)
	res, err = app.Test(req)
	require.NoError(t, err)
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid", string(body))
}