fiber/security/security_test.go
2025-05-24 00:40:41 -04:00

257 lines
6.8 KiB
Go

package security
import (
"encoding/base64"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
func setupApp(handler fiber.Handler) *fiber.App {
app := fiber.New()
app.Get("/", handler)
return app
}
func Test_APIKeyCookie(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
key, err := APIKeyCookie(c, "api")
if err != nil {
return err
}
return c.SendString(key)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "api", Value: "secret"})
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Equal(t, "secret", string(body))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
badApp := setupApp(func(c fiber.Ctx) error {
_, err := APIKeyCookie(c, "")
return err
})
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = badApp.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
func Test_APIKeyHeader(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
key, err := APIKeyHeader(c, "X-API-Key")
if err != nil {
return err
}
return c.SendString(key)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "secret")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Equal(t, "secret", string(body))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
badApp := setupApp(func(c fiber.Ctx) error {
_, err := APIKeyHeader(c, "")
return err
})
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = badApp.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
func Test_APIKeyQuery(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
key, err := APIKeyQuery(c, "key")
if err != nil {
return err
}
return c.SendString(key)
})
req := httptest.NewRequest(fiber.MethodGet, "/?key=secret", nil)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Equal(t, "secret", string(body))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
badApp := setupApp(func(c fiber.Ctx) error {
_, err := APIKeyQuery(c, "")
return err
})
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = badApp.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
func Test_GetAuthorizationCredentials(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
cred, err := GetAuthorizationCredentials(c)
if err != nil {
return err
}
return c.JSON(cred)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer token")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "Bearer")
require.Contains(t, string(body), "token")
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "badheader")
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
}
func Test_HTTPBearer(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
token, err := HTTPBearer(c)
if err != nil {
return err
}
return c.SendString(token)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer tok")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Equal(t, "tok", string(body))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Basic foo")
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
}
func Test_HTTPBasic(t *testing.T) {
t.Parallel()
creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
app := setupApp(func(c fiber.Ctx) error {
cred, err := HTTPBasic(c)
if err != nil {
return err
}
return c.JSON(cred)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Basic "+creds)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "john")
require.Contains(t, string(body), "doe")
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer token")
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
bad := setupApp(func(c fiber.Ctx) error {
_, err := HTTPBasic(c)
return err
})
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Basic !!")
resp, err = bad.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
}
func Test_HTTPDigest(t *testing.T) {
t.Parallel()
app := setupApp(func(c fiber.Ctx) error {
token, err := HTTPDigest(c)
if err != nil {
return err
}
return c.SendString(token)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Digest abc")
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
require.Equal(t, "abc", string(body))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer xyz")
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
}