fiber/middleware/csrf/config_test.go
2025-08-01 16:57:52 +02:00

408 lines
9.7 KiB
Go

package csrf
import (
"fmt"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test security validation functions
func Test_CSRF_ExtractorSecurity_Validation(t *testing.T) {
t.Parallel()
// Test secure configurations - should not panic
t.Run("SecureConfigurations", func(t *testing.T) {
t.Parallel()
secureConfigs := []Config{
{Extractor: FromHeader("X-Csrf-Token")},
{Extractor: FromForm("_csrf")},
{Extractor: FromQuery("csrf_token")},
{Extractor: FromParam("csrf")},
{Extractor: Chain(FromHeader("X-Csrf-Token"), FromForm("_csrf"))},
}
for i, cfg := range secureConfigs {
t.Run(fmt.Sprintf("Config%d", i), func(t *testing.T) {
require.NotPanics(t, func() {
configDefault(cfg)
})
})
}
})
// Test insecure configurations - should panic
t.Run("InsecureCookieExtractor", func(t *testing.T) {
t.Parallel()
// Create a custom extractor that reads from cookie (simulating dangerous behavior)
insecureCookieExtractor := Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("csrf_"), nil
},
Source: SourceCookie,
Key: "csrf_",
}
cfg := Config{
CookieName: "csrf_",
Extractor: insecureCookieExtractor,
}
require.Panics(t, func() {
configDefault(cfg)
}, "Should panic when extractor reads from same cookie")
})
// Test insecure chained extractors
t.Run("InsecureChainedExtractor", func(t *testing.T) {
t.Parallel()
insecureCookieExtractor := Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("csrf_"), nil
},
Source: SourceCookie,
Key: "csrf_",
}
chainedExtractor := Chain(
FromHeader("X-Csrf-Token"),
insecureCookieExtractor, // This should trigger panic
)
cfg := Config{
CookieName: "csrf_",
Extractor: chainedExtractor,
}
require.Panics(t, func() {
configDefault(cfg)
}, "Should panic when chained extractor reads from same cookie")
})
// Test different cookie names - should be secure
t.Run("DifferentCookieNames", func(t *testing.T) {
t.Parallel()
cookieExtractor := Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("different_cookie"), nil
},
Source: SourceCookie,
Key: "different_cookie",
}
cfg := Config{
CookieName: "csrf_",
Extractor: cookieExtractor,
}
require.NotPanics(t, func() {
configDefault(cfg)
}, "Should not panic when extractor reads from different cookie")
})
}
// Test extractor metadata
func Test_CSRF_Extractor_Metadata(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
expectedKey string
extractor Extractor
expectedSource Source
}{
{
name: "FromHeader",
extractor: FromHeader("X-Custom-Token"),
expectedSource: SourceHeader,
expectedKey: "X-Custom-Token",
},
{
name: "FromForm",
extractor: FromForm("_token"),
expectedSource: SourceForm,
expectedKey: "_token",
},
{
name: "FromQuery",
extractor: FromQuery("token"),
expectedSource: SourceQuery,
expectedKey: "token",
},
{
name: "FromParam",
extractor: FromParam("id"),
expectedSource: SourceParam,
expectedKey: "id",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.expectedSource, tc.extractor.Source)
require.Equal(t, tc.expectedKey, tc.extractor.Key)
require.NotNil(t, tc.extractor.Extract)
})
}
}
// Test chain extractor metadata
func Test_CSRF_Chain_Extractor_Metadata(t *testing.T) {
t.Parallel()
t.Run("EmptyChain", func(t *testing.T) {
t.Parallel()
chained := Chain()
require.Equal(t, SourceCustom, chained.Source)
require.Equal(t, "", chained.Key)
require.Empty(t, chained.Chain)
})
t.Run("SingleExtractor", func(t *testing.T) {
t.Parallel()
header := FromHeader("X-Token")
chained := Chain(header)
require.Equal(t, SourceHeader, chained.Source)
require.Equal(t, "X-Token", chained.Key)
require.Len(t, chained.Chain, 1)
})
t.Run("MultipleExtractors", func(t *testing.T) {
t.Parallel()
header := FromHeader("X-Token")
form := FromForm("_csrf")
chained := Chain(header, form)
// Should use first extractor's metadata
require.Equal(t, SourceHeader, chained.Source)
require.Equal(t, "X-Token", chained.Key)
require.Len(t, chained.Chain, 2)
require.Equal(t, header.Source, chained.Chain[0].Source)
require.Equal(t, form.Source, chained.Chain[1].Source)
})
}
// Test custom extractor with new struct pattern
func Test_CSRF_Custom_Extractor_Struct(t *testing.T) {
t.Parallel()
app := fiber.New()
// Custom extractor using new struct pattern
customExtractor := Extractor{
Extract: func(c fiber.Ctx) (string, error) {
// Extract from custom header
token := c.Get("X-Custom-CSRF")
if token == "" {
return "", ErrMissingHeader
}
return token, nil
},
Source: SourceCustom,
Key: "X-Custom-CSRF",
}
app.Use(New(Config{Extractor: customExtractor}))
app.Post("/", func(c fiber.Ctx) error {
return c.SendString("OK")
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Test with custom header
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set("X-Custom-CSRF", token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test without custom header
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())
}
// Test error types for different extractors
func Test_CSRF_Extractor_Error_Types(t *testing.T) {
t.Parallel()
testCases := []struct {
expectedError error
setupRequest func(*fasthttp.RequestCtx)
name string
extractor Extractor
}{
{
name: "MissingHeader",
extractor: FromHeader("X-Missing"),
setupRequest: func(_ *fasthttp.RequestCtx) {
// Don't set the header
},
expectedError: ErrMissingHeader,
},
{
name: "MissingForm",
extractor: FromForm("_missing"),
setupRequest: func(ctx *fasthttp.RequestCtx) {
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
// Don't set form data
},
expectedError: ErrMissingForm,
},
{
name: "MissingQuery",
extractor: FromQuery("missing"),
setupRequest: func(ctx *fasthttp.RequestCtx) {
ctx.Request.SetRequestURI("/")
// Don't set query param
},
expectedError: ErrMissingQuery,
},
{
name: "MissingParam",
extractor: FromParam("missing"),
setupRequest: func(_ *fasthttp.RequestCtx) {
// This would need special route setup to test properly
// For now, we'll test the extractor directly
},
expectedError: ErrMissingParam,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := &fasthttp.RequestCtx{}
tc.setupRequest(ctx)
c := app.AcquireCtx(ctx)
_, err := tc.extractor.Extract(c)
require.Error(t, err)
require.Equal(t, tc.expectedError, err)
app.ReleaseCtx(c)
})
}
}
// Test security warning logs (would need to capture log output in real implementation)
func Test_CSRF_Security_Warnings(t *testing.T) {
t.Parallel()
// Test that insecure extractors trigger warnings
// Note: In a real implementation, you'd want to capture log output
// For now, we just test that the configuration doesn't panic
insecureConfigs := []Config{
{Extractor: FromQuery("csrf_token")},
{Extractor: FromParam("csrf")},
}
for i, cfg := range insecureConfigs {
t.Run(fmt.Sprintf("InsecureConfig%d", i), func(t *testing.T) {
t.Parallel()
require.NotPanics(t, func() {
configDefault(cfg)
})
})
}
}
// Test isInsecureCookieExtractor function directly
func Test_isInsecureCookieExtractor(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
cookieName string
extractor Extractor
expected bool
}{
{
name: "SecureHeaderExtractor",
extractor: Extractor{
Source: SourceHeader,
Key: "X-Csrf-Token",
},
cookieName: "csrf_",
expected: false,
},
{
name: "InsecureCookieExtractor",
extractor: Extractor{
Source: SourceCookie,
Key: "csrf_",
},
cookieName: "csrf_",
expected: true,
},
{
name: "CookieExtractorDifferentName",
extractor: Extractor{
Source: SourceCookie,
Key: "different_cookie",
},
cookieName: "csrf_",
expected: false,
},
{
name: "CustomExtractorSafeName",
extractor: Extractor{
Source: SourceCustom,
Key: "safe_key",
},
cookieName: "csrf_",
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := isInsecureCookieExtractor(tc.extractor, tc.cookieName)
require.Equal(t, tc.expected, result)
})
}
}
func Test_CSRF_CookieName_CaseInsensitive_Warning(t *testing.T) {
t.Parallel()
// Extractor uses "CSRF_" (uppercase), config uses "csrf_" (lowercase)
extractor := Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("CSRF_"), nil
},
Source: SourceCookie,
Key: "CSRF_",
}
cfg := Config{
CookieName: "csrf_",
Extractor: extractor,
}
// Should not panic, but should log a warning
require.NotPanics(t, func() {
configDefault(cfg)
}, "Should not panic for case-insensitive cookie name match, but should warn")
}