mirror of https://github.com/gofiber/fiber.git
refactor(timeout): unify and enhance timeout middleware (#3275)
* feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * refactor(timeout): remove goroutine-based logic and improve documentation - Switch to a synchronous approach to avoid data races with fasthttp context - Enhance error handling for deadline and custom errors - Update comments for clarity and maintainability * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>pull/3277/head
parent
86d72bbba8
commit
bc37f209bf
|
@ -8,23 +8,52 @@ import (
|
|||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
|
||||
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
|
||||
// New enforces a timeout for each incoming request. If the timeout expires or
|
||||
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
|
||||
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
|
||||
return func(ctx fiber.Ctx) error {
|
||||
timeoutContext, cancel := context.WithTimeout(ctx.Context(), t)
|
||||
defer cancel()
|
||||
ctx.SetContext(timeoutContext)
|
||||
if err := h(ctx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
for i := range tErrs {
|
||||
if errors.Is(err, tErrs[i]) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
}
|
||||
return err
|
||||
// If timeout <= 0, skip context.WithTimeout and run the handler as-is.
|
||||
if timeout <= 0 {
|
||||
return runHandler(ctx, h, tErrs)
|
||||
}
|
||||
return nil
|
||||
|
||||
// Create a context with the specified timeout; any operation exceeding
|
||||
// this deadline will be canceled automatically.
|
||||
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Replace the default Fiber context with our timeout-bound context.
|
||||
ctx.SetContext(timeoutContext)
|
||||
|
||||
// Run the handler and check for relevant errors.
|
||||
err := runHandler(ctx, h, tErrs)
|
||||
|
||||
// If the context actually timed out, return a timeout error.
|
||||
if errors.Is(timeoutContext.Err(), context.DeadlineExceeded) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// runHandler executes the handler and returns fiber.ErrRequestTimeout if it
|
||||
// sees a deadline exceeded error or one of the custom "timeout-like" errors.
|
||||
func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error {
|
||||
// Execute the wrapped handler synchronously.
|
||||
err := h(c)
|
||||
// If the context has timed out, return a request timeout error.
|
||||
if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// isCustomError checks whether err matches any error in errList using errors.Is.
|
||||
func isCustomError(err error, errList []error) bool {
|
||||
for _, e := range errList {
|
||||
if errors.Is(err, e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -12,77 +12,119 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// go test -run Test_WithContextTimeout
|
||||
func Test_WithContextTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c fiber.Ctx) error {
|
||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
require.NoError(t, err)
|
||||
if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil {
|
||||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
||||
}
|
||||
return nil
|
||||
}, 100*time.Millisecond)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
require.NoError(t, err, "app.Test(req)")
|
||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
require.NoError(t, err, "app.Test(req)")
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
testTimeout("300")
|
||||
testTimeout("500")
|
||||
testSucces("50")
|
||||
testSucces("30")
|
||||
}
|
||||
var (
|
||||
// Custom error that we treat like a timeout when returned by the handler.
|
||||
errCustomTimeout = errors.New("custom timeout error")
|
||||
|
||||
var ErrFooTimeOut = errors.New("foo context canceled")
|
||||
|
||||
// go test -run Test_WithContextTimeoutWithCustomError
|
||||
func Test_WithContextTimeoutWithCustomError(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c fiber.Ctx) error {
|
||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
require.NoError(t, err)
|
||||
if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil {
|
||||
return fmt.Errorf("%w: execution error", err)
|
||||
}
|
||||
return nil
|
||||
}, 100*time.Millisecond, ErrFooTimeOut)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
require.NoError(t, err, "app.Test(req)")
|
||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
||||
require.NoError(t, err, "app.Test(req)")
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
testTimeout("300")
|
||||
testTimeout("500")
|
||||
testSucces("50")
|
||||
testSucces("30")
|
||||
}
|
||||
// Some unrelated error that should NOT trigger a request timeout.
|
||||
errUnrelated = errors.New("unmatched error")
|
||||
)
|
||||
|
||||
// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled.
|
||||
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop() // Clean up the timer
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return te
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestTimeout_Success tests a handler that completes within the allotted timeout.
|
||||
func TestTimeout_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
// Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit.
|
||||
app.Get("/fast", New(func(c fiber.Ctx) error {
|
||||
// Simulate some work
|
||||
if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SendString("OK")
|
||||
}, 50*time.Millisecond))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/fast", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err, "app.Test(req) should not fail")
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
|
||||
}
|
||||
|
||||
// TestTimeout_Exceeded tests a handler that exceeds the provided timeout.
|
||||
func TestTimeout_Exceeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
// This handler sleeps 200ms, exceeding the 100ms limit.
|
||||
app.Get("/slow", New(func(c fiber.Ctx) error {
|
||||
if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.SendString("Should never get here")
|
||||
}, 100*time.Millisecond))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/slow", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err, "app.Test(req) should not fail")
|
||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
|
||||
}
|
||||
|
||||
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
|
||||
func TestTimeout_CustomError(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
// This handler sleeps 50ms and returns errCustomTimeout if canceled.
|
||||
app.Get("/custom", New(func(c fiber.Ctx) error {
|
||||
// Sleep might time out, or might return early. If the context is canceled,
|
||||
// we treat errCustomTimeout as a 'timeout-like' condition.
|
||||
if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil {
|
||||
return fmt.Errorf("wrapped: %w", err)
|
||||
}
|
||||
return c.SendString("Should never get here")
|
||||
}, 100*time.Millisecond, errCustomTimeout))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/custom", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err, "app.Test(req) should not fail")
|
||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
|
||||
}
|
||||
|
||||
// TestTimeout_UnmatchedError checks that if the handler returns an error
|
||||
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
|
||||
// propagated as a regular 500 (internal server error).
|
||||
func TestTimeout_UnmatchedError(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/unmatched", New(func(_ fiber.Ctx) error {
|
||||
return errUnrelated // Not in the custom error list
|
||||
}, 100*time.Millisecond, errCustomTimeout))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err, "app.Test(req) should not fail")
|
||||
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
|
||||
"Expected 500 because the error is not recognized as a timeout error")
|
||||
}
|
||||
|
||||
// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero.
|
||||
// Usually this means the request can never exceed a 'deadline' – effectively no timeout.
|
||||
func TestTimeout_ZeroDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/zero", New(func(c fiber.Ctx) error {
|
||||
// Sleep 50ms, but there's no real 'deadline' since zero-timeout.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return c.SendString("No timeout used")
|
||||
}, 0))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/zero", nil)
|
||||
resp, err := app.Test(req)
|
||||
require.NoError(t, err, "app.Test(req) should not fail")
|
||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue