mirror of
https://github.com/gofiber/fiber.git
synced 2025-07-11 13:08:04 +00:00
263 lines
6.5 KiB
Go
263 lines
6.5 KiB
Go
package earlydata
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// Early-Data request header name and the two values the tests send with it.
const (
	// headerName is the request header that carries the early-data marker.
	headerName = "Early-Data"

	// headerValOff is sent to mark a request as not being early data.
	headerValOff = "0"
	// headerValOn is sent to mark a request as early data.
	headerValOn = "1"
)
|
|
|
|
func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
|
|
t.Helper()
|
|
t.Parallel()
|
|
|
|
var app *fiber.App
|
|
if c == nil {
|
|
app = fiber.New()
|
|
} else {
|
|
app = fiber.New(*c)
|
|
}
|
|
|
|
app.Use(New())
|
|
|
|
// Middleware to test IsEarly func
|
|
const localsKeyTestValid = "earlydata_testvalid"
|
|
app.Use(func(c fiber.Ctx) error {
|
|
isEarly := IsEarly(c)
|
|
|
|
switch h := c.Get(headerName); h {
|
|
case "", headerValOff:
|
|
if isEarly {
|
|
return errors.New("is early-data even though it's not")
|
|
}
|
|
|
|
case headerValOn:
|
|
switch {
|
|
case fiber.IsMethodSafe(c.Method()):
|
|
if !isEarly {
|
|
return errors.New("should be early-data on safe HTTP methods")
|
|
}
|
|
default:
|
|
if isEarly {
|
|
return errors.New("early-data unsupported on unsafe HTTP methods")
|
|
}
|
|
}
|
|
|
|
default:
|
|
return fmt.Errorf("header has unsupported value: %s", h)
|
|
}
|
|
|
|
_ = c.Locals(localsKeyTestValid, true)
|
|
|
|
return c.Next()
|
|
})
|
|
|
|
app.Add([]string{
|
|
fiber.MethodGet,
|
|
fiber.MethodPost,
|
|
}, "/", func(c fiber.Ctx) error {
|
|
valid, ok := c.Locals(localsKeyTestValid).(bool)
|
|
if !ok {
|
|
panic(errors.New("failed to type-assert to bool"))
|
|
}
|
|
if !valid {
|
|
return errors.New("handler called even though validation failed")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return app
|
|
}
|
|
|
|
// go test -run Test_EarlyData
|
|
func Test_EarlyData(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
trustedRun := func(t *testing.T, app *fiber.App) {
|
|
t.Helper()
|
|
|
|
{
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOff)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOn)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
{
|
|
req := httptest.NewRequest(fiber.MethodPost, "/", nil)
|
|
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOff)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOn)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
untrustedRun := func(t *testing.T, app *fiber.App) {
|
|
t.Helper()
|
|
|
|
{
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOff)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOn)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
}
|
|
|
|
{
|
|
req := httptest.NewRequest(fiber.MethodPost, "/", nil)
|
|
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOff)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
|
|
req.Header.Set(headerName, headerValOn)
|
|
resp, err = app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
t.Run("empty config", func(t *testing.T) {
|
|
app := appWithConfig(t, nil)
|
|
trustedRun(t, app)
|
|
})
|
|
t.Run("default config", func(t *testing.T) {
|
|
app := appWithConfig(t, &fiber.Config{})
|
|
trustedRun(t, app)
|
|
})
|
|
|
|
t.Run("config with TrustProxy", func(t *testing.T) {
|
|
app := appWithConfig(t, &fiber.Config{
|
|
TrustProxy: true,
|
|
})
|
|
untrustedRun(t, app)
|
|
})
|
|
t.Run("config with TrustProxy and trusted TrustProxyConfig.Proxies", func(t *testing.T) {
|
|
app := appWithConfig(t, &fiber.Config{
|
|
TrustProxy: true,
|
|
TrustProxyConfig: fiber.TrustProxyConfig{
|
|
Proxies: []string{
|
|
"0.0.0.0",
|
|
},
|
|
},
|
|
})
|
|
trustedRun(t, app)
|
|
})
|
|
}
|
|
|
|
// Test_EarlyDataNext verifies that the middleware skips its logic when Next returns true.
|
|
func Test_EarlyDataNext(t *testing.T) {
|
|
t.Parallel()
|
|
app := fiber.New()
|
|
|
|
app.Use(New(Config{
|
|
Next: func(fiber.Ctx) bool { return true },
|
|
}))
|
|
|
|
called := false
|
|
app.Get("/", func(c fiber.Ctx) error {
|
|
called = true
|
|
if IsEarly(c) {
|
|
return errors.New("IsEarly(c) should be false when Next returns true")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
|
req.Header.Set(headerName, headerValOn)
|
|
resp, err := app.Test(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fiber.StatusOK, resp.StatusCode)
|
|
require.True(t, called)
|
|
}
|
|
|
|
// Test_configDefault_NoConfig verifies that calling configDefault without
|
|
// providing a configuration returns ConfigDefault as-is.
|
|
func Test_configDefault_NoConfig(t *testing.T) {
|
|
t.Parallel()
|
|
cfg := configDefault()
|
|
require.Equal(t, ConfigDefault.Error, cfg.Error)
|
|
require.Equal(t, reflect.ValueOf(ConfigDefault.IsEarlyData).Pointer(), reflect.ValueOf(cfg.IsEarlyData).Pointer())
|
|
require.Equal(t, reflect.ValueOf(ConfigDefault.AllowEarlyData).Pointer(), reflect.ValueOf(cfg.AllowEarlyData).Pointer())
|
|
}
|
|
|
|
// Test_configDefault_WithConfig verifies that provided configuration fields are
|
|
// kept while missing fields are populated with defaults.
|
|
func Test_configDefault_WithConfig(t *testing.T) {
|
|
t.Parallel()
|
|
expectedErr := errors.New("boom")
|
|
called := false
|
|
custom := Config{
|
|
Next: func(_ fiber.Ctx) bool { called = true; return false },
|
|
Error: expectedErr,
|
|
}
|
|
|
|
cfg := configDefault(custom)
|
|
|
|
// Next should be preserved and not invoked by configDefault.
|
|
require.False(t, called)
|
|
require.Equal(t, reflect.ValueOf(custom.Next).Pointer(), reflect.ValueOf(cfg.Next).Pointer())
|
|
// Custom error must be preserved.
|
|
require.Equal(t, expectedErr, cfg.Error)
|
|
// Missing fields should be set to defaults.
|
|
require.NotNil(t, cfg.IsEarlyData)
|
|
require.NotNil(t, cfg.AllowEarlyData)
|
|
|
|
// Verify default functions behave as expected.
|
|
app := fiber.New()
|
|
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
|
c.Request().Header.Set(DefaultHeaderName, DefaultHeaderTrueValue)
|
|
c.Request().Header.SetMethod(fiber.MethodGet)
|
|
require.True(t, cfg.IsEarlyData(c))
|
|
require.True(t, cfg.AllowEarlyData(c))
|
|
app.ReleaseCtx(c)
|
|
}
|