diff --git a/docs/guide/security.md b/docs/guide/security.md new file mode 100644 index 00000000..b7ca1b4e --- /dev/null +++ b/docs/guide/security.md @@ -0,0 +1,63 @@ +--- +id: security +title: 🔒 Security Helpers +sidebar_position: 8 +--- + +Fiber provides helper functions for common security tasks like extracting API keys or credentials from a request. +These utilities can be used with your middleware or handlers. + +## API Key helpers + +```go +import "github.com/gofiber/fiber/v3/security" + +func handler(c fiber.Ctx) error { + key, err := security.APIKeyHeader(c, "X-API-Key") + if err != nil { + return err + } + // use key + return nil +} +``` + +Available helpers: + +- `APIKeyCookie(c fiber.Ctx, name string)` +- `APIKeyHeader(c fiber.Ctx, header string)` +- `APIKeyQuery(c fiber.Ctx, name string)` + +Each returns the key or `fiber.ErrUnauthorized` when the key is missing. + +```go +// Cookie +key, _ := security.APIKeyCookie(c, "session") + +// Query parameter +key, _ = security.APIKeyQuery(c, "api_key") +``` + +## Authorization helpers + +```go +cred, err := security.GetAuthorizationCredentials(c) +``` + +Use `HTTPBearer`, `HTTPBasic`, or `HTTPDigest` to parse common Authorization schemes. + +```go +bearer, err := security.HTTPBearer(c) +``` + +```go +user, err := security.HTTPBasic(c) +``` + +```go +digest, err := security.HTTPDigest(c) +``` + +`HTTPBasic` returns `HTTPBasicCredentials` containing the parsed username and password. + + diff --git a/security/security.go b/security/security.go new file mode 100644 index 00000000..df078f74 --- /dev/null +++ b/security/security.go @@ -0,0 +1,115 @@ +package security + +import ( + "encoding/base64" + "strings" + + "github.com/gofiber/fiber/v3" +) + +// APIKeyCookie retrieves an API key from the named cookie. +// It returns ErrBadRequest if the cookie name is empty +// and ErrUnauthorized when the cookie does not exist. +func APIKeyCookie(c fiber.Ctx, name string) (string, error) { + if name == "" { + return "", fiber.NewError(fiber.StatusBadRequest, "name is empty") + } + key := c.Cookies(name) + if key == "" { + return "", fiber.ErrUnauthorized + } + return key, nil +} + +// APIKeyHeader retrieves an API key from the named header. +// It returns ErrBadRequest if the header name is empty +// and ErrUnauthorized when the header is missing. +func APIKeyHeader(c fiber.Ctx, header string) (string, error) { + if header == "" { + return "", fiber.NewError(fiber.StatusBadRequest, "header is empty") + } + key := c.Get(header) + if key == "" { + return "", fiber.ErrUnauthorized + } + return key, nil +} + +// APIKeyQuery retrieves an API key from the given query parameter. +// It returns ErrBadRequest if the query name is empty +// and ErrUnauthorized when the parameter is missing. +func APIKeyQuery(c fiber.Ctx, name string) (string, error) { + if name == "" { + return "", fiber.NewError(fiber.StatusBadRequest, "name is empty") + } + key := fiber.Query[string](c, name) + if key == "" { + return "", fiber.ErrUnauthorized + } + return key, nil +} + +// HTTPAuthorizationCredentials represents the Authorization header parts. +type HTTPAuthorizationCredentials struct { + Scheme string + Token string +} + +// GetAuthorizationCredentials parses the Authorization header. +func GetAuthorizationCredentials(c fiber.Ctx) (HTTPAuthorizationCredentials, error) { + auth := c.Get(fiber.HeaderAuthorization) + if auth == "" { + return HTTPAuthorizationCredentials{}, fiber.ErrUnauthorized + } + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 { + return HTTPAuthorizationCredentials{}, fiber.ErrUnauthorized + } + return HTTPAuthorizationCredentials{Scheme: parts[0], Token: parts[1]}, nil +} + +// HTTPBearer extracts a bearer token from the Authorization header. +func HTTPBearer(c fiber.Ctx) (string, error) { + auth := c.Get(fiber.HeaderAuthorization) + if auth == "" { + return "", fiber.ErrUnauthorized + } + const prefix = "Bearer " + if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", fiber.ErrUnauthorized + } + return auth[len(prefix):], nil +} + +// HTTPBasicCredentials holds parsed HTTP basic auth credentials. +type HTTPBasicCredentials struct { + Username string + Password string +} + +// HTTPBasic parses the Authorization header for basic auth credentials. +func HTTPBasic(c fiber.Ctx) (HTTPBasicCredentials, error) { + auth := c.Get(fiber.HeaderAuthorization) + if len(auth) <= 6 || !strings.EqualFold(auth[:6], "Basic ") { + return HTTPBasicCredentials{}, fiber.ErrUnauthorized + } + decoded, err := base64.StdEncoding.DecodeString(auth[6:]) + if err != nil { + return HTTPBasicCredentials{}, fiber.ErrUnauthorized + } + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + return HTTPBasicCredentials{}, fiber.ErrUnauthorized + } + return HTTPBasicCredentials{Username: parts[0], Password: parts[1]}, nil +} + +// HTTPDigest retrieves the digest value from the Authorization header. +func HTTPDigest(c fiber.Ctx) (string, error) { + auth := c.Get(fiber.HeaderAuthorization) + const prefix = "Digest " + if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", fiber.ErrUnauthorized + } + return auth[len(prefix):], nil +} diff --git a/security/security_test.go b/security/security_test.go new file mode 100644 index 00000000..7fc6a84a --- /dev/null +++ b/security/security_test.go @@ -0,0 +1,256 @@ +package security + +import ( + "encoding/base64" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func setupApp(handler fiber.Handler) *fiber.App { + app := fiber.New() + app.Get("/", handler) + return app +} + +func Test_APIKeyCookie(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + key, err := APIKeyCookie(c, "api") + if err != nil { + return err + } + return c.SendString(key) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "api", Value: "secret"}) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "secret", string(body)) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + badApp := setupApp(func(c fiber.Ctx) error { + _, err := APIKeyCookie(c, "") + return err + }) + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = badApp.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_APIKeyHeader(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + key, err := APIKeyHeader(c, "X-API-Key") + if err != nil { + return err + } + return c.SendString(key) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "secret") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "secret", string(body)) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + badApp := setupApp(func(c fiber.Ctx) error { + _, err := APIKeyHeader(c, "") + return err + }) + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = badApp.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_APIKeyQuery(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + key, err := APIKeyQuery(c, "key") + if err != nil { + return err + } + return c.SendString(key) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/?key=secret", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "secret", string(body)) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + badApp := setupApp(func(c fiber.Ctx) error { + _, err := APIKeyQuery(c, "") + return err + }) + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = badApp.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_GetAuthorizationCredentials(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + cred, err := GetAuthorizationCredentials(c) + if err != nil { + return err + } + return c.JSON(cred) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer token") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Contains(t, string(body), "Bearer") + require.Contains(t, string(body), "token") + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "badheader") + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) +} + +func Test_HTTPBearer(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + token, err := HTTPBearer(c) + if err != nil { + return err + } + return c.SendString(token) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer tok") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "tok", string(body)) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Basic foo") + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) +} + +func Test_HTTPBasic(t *testing.T) { + t.Parallel() + + creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) + app := setupApp(func(c fiber.Ctx) error { + cred, err := HTTPBasic(c) + if err != nil { + return err + } + return c.JSON(cred) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Contains(t, string(body), "john") + require.Contains(t, string(body), "doe") + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer token") + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + bad := setupApp(func(c fiber.Ctx) error { + _, err := HTTPBasic(c) + return err + }) + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Basic !!") + resp, err = bad.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) +} + +func Test_HTTPDigest(t *testing.T) { + t.Parallel() + + app := setupApp(func(c fiber.Ctx) error { + token, err := HTTPDigest(c) + if err != nil { + return err + } + return c.SendString(token) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Digest abc") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "abc", string(body)) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) + + req = httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer xyz") + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) +}