refactor(timeout): unify and enhance timeout middleware (#3275)

* feat(timeout): unify and enhance timeout middleware

- Combine classic context-based timeout with a Goroutine + channel approach
- Support custom error list without additional parameters
- Return fiber.ErrRequestTimeout for timeouts or listed errors

* feat(timeout): unify and enhance timeout middleware

- Combine classic context-based timeout with a Goroutine + channel approach
- Support custom error list without additional parameters
- Return fiber.ErrRequestTimeout for timeouts or listed errors

* refactor(timeout): remove goroutine-based logic and improve documentation

- Switch to a synchronous approach to avoid data races with fasthttp context
- Enhance error handling for deadline and custom errors
- Update comments for clarity and maintainability

* refactor(timeout): add more test cases and handle zero duration case

* refactor(timeout): add more test cases and handle zero duration case

* refactor(timeout): add more test cases and handle zero duration case

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
pull/3277/head
RW 2025-01-08 08:19:20 +01:00 committed by GitHub
parent 86d72bbba8
commit bc37f209bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 151 additions and 80 deletions

View File

@ -8,23 +8,52 @@ import (
"github.com/gofiber/fiber/v3"
)
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
// New enforces a timeout for each incoming request. If the timeout expires or
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
return func(ctx fiber.Ctx) error {
timeoutContext, cancel := context.WithTimeout(ctx.Context(), t)
defer cancel()
ctx.SetContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err
// If timeout <= 0, skip context.WithTimeout and run the handler as-is.
if timeout <= 0 {
return runHandler(ctx, h, tErrs)
}
return nil
// Create a context with the specified timeout; any operation exceeding
// this deadline will be canceled automatically.
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
defer cancel()
// Replace the default Fiber context with our timeout-bound context.
ctx.SetContext(timeoutContext)
// Run the handler and check for relevant errors.
err := runHandler(ctx, h, tErrs)
// If the context actually timed out, return a timeout error.
if errors.Is(timeoutContext.Err(), context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
return err
}
}
// runHandler executes the handler and returns fiber.ErrRequestTimeout if it
// sees a deadline exceeded error or one of the custom "timeout-like" errors.
func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error {
// Execute the wrapped handler synchronously.
err := h(c)
// If the context has timed out, return a request timeout error.
if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) {
return fiber.ErrRequestTimeout
}
return err
}
// isCustomError checks whether err matches any error in errList using errors.Is.
func isCustomError(err error, errList []error) bool {
for _, e := range errList {
if errors.Is(err, e) {
return true
}
}
return false
}

View File

@ -12,77 +12,119 @@ import (
"github.com/stretchr/testify/require"
)
// go test -run Test_WithContextTimeout
func Test_WithContextTimeout(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}
var (
// Custom error that we treat like a timeout when returned by the handler.
errCustomTimeout = errors.New("custom timeout error")
var ErrFooTimeOut = errors.New("foo context canceled")
// go test -run Test_WithContextTimeoutWithCustomError
func Test_WithContextTimeoutWithCustomError(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}
// Some unrelated error that should NOT trigger a request timeout.
errUnrelated = errors.New("unmatched error")
)
// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled.
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
defer timer.Stop() // Clean up the timer
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return te
case <-timer.C:
return nil
}
return nil
}
// TestTimeout_Success tests a handler that completes within the allotted timeout.
func TestTimeout_Success(t *testing.T) {
t.Parallel()
app := fiber.New()
// Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit.
app.Get("/fast", New(func(c fiber.Ctx) error {
// Simulate some work
if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return c.SendString("OK")
}, 50*time.Millisecond))
req := httptest.NewRequest(fiber.MethodGet, "/fast", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
}
// TestTimeout_Exceeded tests a handler that exceeds the provided timeout.
func TestTimeout_Exceeded(t *testing.T) {
t.Parallel()
app := fiber.New()
// This handler sleeps 200ms, exceeding the 100ms limit.
app.Get("/slow", New(func(c fiber.Ctx) error {
if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return c.SendString("Should never get here")
}, 100*time.Millisecond))
req := httptest.NewRequest(fiber.MethodGet, "/slow", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
}
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
func TestTimeout_CustomError(t *testing.T) {
t.Parallel()
app := fiber.New()
// This handler sleeps 50ms and returns errCustomTimeout if canceled.
app.Get("/custom", New(func(c fiber.Ctx) error {
// Sleep might time out, or might return early. If the context is canceled,
// we treat errCustomTimeout as a 'timeout-like' condition.
if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil {
return fmt.Errorf("wrapped: %w", err)
}
return c.SendString("Should never get here")
}, 100*time.Millisecond, errCustomTimeout))
req := httptest.NewRequest(fiber.MethodGet, "/custom", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
}
// TestTimeout_UnmatchedError checks that if the handler returns an error
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
// propagated as a regular 500 (internal server error).
func TestTimeout_UnmatchedError(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/unmatched", New(func(_ fiber.Ctx) error {
return errUnrelated // Not in the custom error list
}, 100*time.Millisecond, errCustomTimeout))
req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
"Expected 500 because the error is not recognized as a timeout error")
}
// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero.
// Usually this means the request can never exceed a 'deadline' effectively no timeout.
func TestTimeout_ZeroDuration(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/zero", New(func(c fiber.Ctx) error {
// Sleep 50ms, but there's no real 'deadline' since zero-timeout.
time.Sleep(50 * time.Millisecond)
return c.SendString("No timeout used")
}, 0))
req := httptest.NewRequest(fiber.MethodGet, "/zero", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
}