mirror of https://github.com/gofiber/fiber.git
refactor(timeout): add more test cases and handle zero duration case
parent
32f1e68a94
commit
803c5b9ab7
|
@ -12,6 +12,11 @@ import (
|
||||||
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
|
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
|
||||||
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
|
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
|
||||||
return func(ctx fiber.Ctx) error {
|
return func(ctx fiber.Ctx) error {
|
||||||
|
// If timeout <= 0, skip context.WithTimeout and run the handler as-is.
|
||||||
|
if timeout <= 0 {
|
||||||
|
return runHandler(ctx, h, tErrs)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a context with the specified timeout; any operation exceeding
|
// Create a context with the specified timeout; any operation exceeding
|
||||||
// this deadline will be canceled automatically.
|
// this deadline will be canceled automatically.
|
||||||
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
|
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
|
||||||
|
@ -39,6 +44,16 @@ func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runHandler executes the handler and returns fiber.ErrRequestTimeout if it
|
||||||
|
// sees a deadline exceeded error or one of the custom "timeout-like" errors.
|
||||||
|
func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error {
|
||||||
|
err := h(c)
|
||||||
|
if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) {
|
||||||
|
return fiber.ErrRequestTimeout
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// isCustomError checks whether err matches any error in errList using errors.Is.
|
// isCustomError checks whether err matches any error in errList using errors.Is.
|
||||||
func isCustomError(err error, errList []error) bool {
|
func isCustomError(err error, errList []error) bool {
|
||||||
for _, e := range errList {
|
for _, e := range errList {
|
||||||
|
|
|
@ -12,77 +12,114 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// go test -run Test_WithContextTimeout
|
var (
|
||||||
func Test_WithContextTimeout(t *testing.T) {
|
// Custom error that we treat like a timeout when returned by the handler.
|
||||||
t.Parallel()
|
errCustomTimeout = errors.New("custom timeout error")
|
||||||
// fiber instance
|
|
||||||
app := fiber.New()
|
|
||||||
h := New(func(c fiber.Ctx) error {
|
|
||||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
|
||||||
require.NoError(t, err)
|
|
||||||
if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil {
|
|
||||||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}, 100*time.Millisecond)
|
|
||||||
app.Get("/test/:sleepTime", h)
|
|
||||||
testTimeout := func(timeoutStr string) {
|
|
||||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
|
||||||
require.NoError(t, err, "app.Test(req)")
|
|
||||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
|
||||||
}
|
|
||||||
testSucces := func(timeoutStr string) {
|
|
||||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
|
||||||
require.NoError(t, err, "app.Test(req)")
|
|
||||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
|
||||||
}
|
|
||||||
testTimeout("300")
|
|
||||||
testTimeout("500")
|
|
||||||
testSucces("50")
|
|
||||||
testSucces("30")
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrFooTimeOut = errors.New("foo context canceled")
|
// Some unrelated error that should NOT trigger a request timeout.
|
||||||
|
errUnrelated = errors.New("unmatched error")
|
||||||
// go test -run Test_WithContextTimeoutWithCustomError
|
)
|
||||||
func Test_WithContextTimeoutWithCustomError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
// fiber instance
|
|
||||||
app := fiber.New()
|
|
||||||
h := New(func(c fiber.Ctx) error {
|
|
||||||
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
|
|
||||||
require.NoError(t, err)
|
|
||||||
if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil {
|
|
||||||
return fmt.Errorf("%w: execution error", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}, 100*time.Millisecond, ErrFooTimeOut)
|
|
||||||
app.Get("/test/:sleepTime", h)
|
|
||||||
testTimeout := func(timeoutStr string) {
|
|
||||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
|
||||||
require.NoError(t, err, "app.Test(req)")
|
|
||||||
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
|
||||||
}
|
|
||||||
testSucces := func(timeoutStr string) {
|
|
||||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
|
|
||||||
require.NoError(t, err, "app.Test(req)")
|
|
||||||
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
|
||||||
}
|
|
||||||
testTimeout("300")
|
|
||||||
testTimeout("500")
|
|
||||||
testSucces("50")
|
|
||||||
testSucces("30")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled.
|
||||||
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
|
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
|
||||||
timer := time.NewTimer(d)
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop() // Clean up the timer
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
if !timer.Stop() {
|
|
||||||
<-timer.C
|
|
||||||
}
|
|
||||||
return te
|
return te
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeout_Success tests a handler that completes within the allotted timeout.
|
||||||
|
func TestTimeout_Success(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
// Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit.
|
||||||
|
app.Get("/fast", New(func(c fiber.Ctx) error {
|
||||||
|
// Simulate some work
|
||||||
|
if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.SendString("OK")
|
||||||
|
}, 50*time.Millisecond))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/fast", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
require.NoError(t, err, "app.Test(req) should not fail")
|
||||||
|
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeout_Exceeded tests a handler that exceeds the provided timeout.
|
||||||
|
func TestTimeout_Exceeded(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
// This handler sleeps 200ms, exceeding the 100ms limit.
|
||||||
|
app.Get("/slow", New(func(c fiber.Ctx) error {
|
||||||
|
if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.SendString("Should never get here")
|
||||||
|
}, 100*time.Millisecond))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/slow", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
require.NoError(t, err, "app.Test(req) should not fail")
|
||||||
|
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
|
||||||
|
func TestTimeout_CustomError(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
// This handler sleeps 50ms and returns errCustomTimeout if canceled.
|
||||||
|
app.Get("/custom", New(func(c fiber.Ctx) error {
|
||||||
|
// Sleep might time out, or might return early. If the context is canceled,
|
||||||
|
// we treat errCustomTimeout as a 'timeout-like' condition.
|
||||||
|
if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil {
|
||||||
|
return fmt.Errorf("wrapped: %w", err)
|
||||||
|
}
|
||||||
|
return c.SendString("Should never get here")
|
||||||
|
}, 100*time.Millisecond, errCustomTimeout))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/custom", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
require.NoError(t, err, "app.Test(req) should not fail")
|
||||||
|
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeout_UnmatchedError checks that if the handler returns an error
|
||||||
|
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
|
||||||
|
// propagated as a regular 500 (internal server error).
|
||||||
|
func TestTimeout_UnmatchedError(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Get("/unmatched", New(func(_ fiber.Ctx) error {
|
||||||
|
return errUnrelated // Not in the custom error list
|
||||||
|
}, 100*time.Millisecond, errCustomTimeout))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
require.NoError(t, err, "app.Test(req) should not fail")
|
||||||
|
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
|
||||||
|
"Expected 500 because the error is not recognized as a timeout error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero.
|
||||||
|
// Usually this means the request can never exceed a 'deadline' – effectively no timeout.
|
||||||
|
func TestTimeout_ZeroDuration(t *testing.T) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Get("/zero", New(func(c fiber.Ctx) error {
|
||||||
|
// Sleep 50ms, but there's no real 'deadline' since zero-timeout.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return c.SendString("No timeout used")
|
||||||
|
}, 0))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(fiber.MethodGet, "/zero", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
require.NoError(t, err, "app.Test(req) should not fail")
|
||||||
|
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue