// Mirror of https://github.com/gofiber/fiber.git (synced 2025-05-31 11:52:41 +00:00).
// 257 lines, 6.8 KiB, Go.
package security
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func setupApp(handler fiber.Handler) *fiber.App {
|
|
app := fiber.New()
|
|
app.Get("/", handler)
|
|
return app
|
|
}
|
|
|
|
func Test_APIKeyCookie(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
key, err := APIKeyCookie(c, "api")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.SendString(key)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.AddCookie(&http.Cookie{Name: "api", Value: "secret"})
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Equal(t, "secret", string(body))
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
badApp := setupApp(func(c fiber.Ctx) error {
|
|
_, err := APIKeyCookie(c, "")
|
|
return err
|
|
})
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = badApp.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
func Test_APIKeyHeader(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
key, err := APIKeyHeader(c, "X-API-Key")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.SendString(key)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("X-API-Key", "secret")
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Equal(t, "secret", string(body))
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
badApp := setupApp(func(c fiber.Ctx) error {
|
|
_, err := APIKeyHeader(c, "")
|
|
return err
|
|
})
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = badApp.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
func Test_APIKeyQuery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
key, err := APIKeyQuery(c, "key")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.SendString(key)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/?key=secret", nil)
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Equal(t, "secret", string(body))
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
badApp := setupApp(func(c fiber.Ctx) error {
|
|
_, err := APIKeyQuery(c, "")
|
|
return err
|
|
})
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = badApp.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
func Test_GetAuthorizationCredentials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
cred, err := GetAuthorizationCredentials(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.JSON(cred)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer token")
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Contains(t, string(body), "Bearer")
|
|
require.Contains(t, string(body), "token")
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "badheader")
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
func Test_HTTPBearer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
token, err := HTTPBearer(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.SendString(token)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer tok")
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Equal(t, "tok", string(body))
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Basic foo")
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
func Test_HTTPBasic(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
cred, err := HTTPBasic(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.JSON(cred)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Basic "+creds)
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Contains(t, string(body), "john")
|
|
require.Contains(t, string(body), "doe")
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer token")
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
bad := setupApp(func(c fiber.Ctx) error {
|
|
_, err := HTTPBasic(c)
|
|
return err
|
|
})
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Basic !!")
|
|
resp, err = bad.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
func Test_HTTPDigest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
app := setupApp(func(c fiber.Ctx) error {
|
|
token, err := HTTPDigest(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.SendString(token)
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Digest abc")
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
body, _ := io.ReadAll(resp.Body)
|
|
require.Equal(t, "abc", string(body))
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
|
|
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer xyz")
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|