mirror of https://github.com/gofiber/fiber.git
🔥 Update: add timeout context middleware (#2090)
* 🔥 Feature: add timeoutcontext middleware
* move timeoutconext to timeout package
* remove timeoutcontext readme.md
* replace timeout mware with timeout context mware
* Update README.md
* Update README.md
* update timeout middleware readme
* test curl commands fixed
* rename sample code title on timeout middleware
Co-authored-by: RW <rene@gofiber.io>
pull/2099/head
parent
e829caf808
commit
7c83e38757
|
@ -1,5 +1,9 @@
|
|||
# Timeout
|
||||
Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.Handler` with a timeout. If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the centralized [ErrorHandler](https://docs.gofiber.io/error-handling).
|
||||
Timeout middleware for Fiber. As a `fiber.Handler` wrapper, it creates a context with `context.WithTimeout` and pass it in `UserContext`.
|
||||
|
||||
If the context passed executions (eg. DB ops, Http calls) takes longer than the given duration to return, the timeout error is set and forwarded to the centralized `ErrorHandler`.
|
||||
|
||||
It has no race conditions, ready to use on production.
|
||||
|
||||
### Table of Contents
|
||||
- [Signatures](#signatures)
|
||||
|
@ -8,7 +12,7 @@ Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.
|
|||
|
||||
### Signatures
|
||||
```go
|
||||
func New(h fiber.Handler, t time.Duration) fiber.Handler
|
||||
func New(handler fiber.Handler, timeout time.Duration, timeoutErrors ...error) fiber.Handler
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
@ -20,15 +24,76 @@ import (
|
|||
)
|
||||
```
|
||||
|
||||
After you initiate your Fiber app, you can use the following possibilities:
|
||||
Sample timeout middleware usage
|
||||
```go
|
||||
handler := func(ctx *fiber.Ctx) error {
|
||||
err := ctx.SendString("Hello, World 👋!")
|
||||
if err != nil {
|
||||
return err
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
h := func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
if err := sleepWithContext(c.UserContext(), sleepTime); err != nil {
|
||||
return fmt.Errorf("%w: execution error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second))
|
||||
_ = app.Listen(":3000")
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return context.DeadlineExceeded
|
||||
case <-timer.C:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
Test http 200 with curl:
|
||||
```bash
|
||||
curl --location -I --request GET 'http://localhost:3000/foo/1000'
|
||||
```
|
||||
|
||||
Test http 408 with curl:
|
||||
```bash
|
||||
curl --location -I --request GET 'http://localhost:3000/foo/3000'
|
||||
```
|
||||
|
||||
|
||||
When using with custom error:
|
||||
```go
|
||||
var ErrFooTimeOut = errors.New("foo context canceled")
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
h := func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil {
|
||||
return fmt.Errorf("%w: execution error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second), ErrFooTimeOut)
|
||||
_ = app.Listen(":3000")
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ErrFooTimeOut
|
||||
case <-timer.C:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Get("/foo", timeout.New(handler, 5 * time.Second))
|
||||
```
|
||||
|
|
|
@ -1,43 +1,30 @@
|
|||
package timeout
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var once sync.Once
|
||||
|
||||
// New wraps a handler and aborts the process of the handler if the timeout is reached
|
||||
func New(handler fiber.Handler, timeout time.Duration) fiber.Handler {
|
||||
once.Do(func() {
|
||||
fmt.Println("[Warning] timeout contains data race issues, not ready for production!")
|
||||
})
|
||||
|
||||
if timeout <= 0 {
|
||||
return handler
|
||||
}
|
||||
|
||||
// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418
|
||||
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
|
||||
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
ch := make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
_ = handler(ctx)
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(timeout):
|
||||
return fiber.ErrRequestTimeout
|
||||
timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t)
|
||||
defer cancel()
|
||||
ctx.SetUserContext(timeoutContext)
|
||||
if err := h(ctx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
for i := range tErrs {
|
||||
if errors.Is(err, tErrs[i]) {
|
||||
return fiber.ErrRequestTimeout
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,55 +1,84 @@
|
|||
package timeout
|
||||
|
||||
// // go test -run Test_Middleware_Timeout
|
||||
// func Test_Middleware_Timeout(t *testing.T) {
|
||||
// app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// h := New(func(c *fiber.Ctx) error {
|
||||
// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
// time.Sleep(sleepTime)
|
||||
// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping")
|
||||
// }, 5*time.Millisecond)
|
||||
// app.Get("/test/:sleepTime", h)
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// testTimeout := func(timeoutStr string) {
|
||||
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
// utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
// go test -run Test_Timeout
|
||||
func Test_Timeout(t *testing.T) {
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
|
||||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
|
||||
}
|
||||
return nil
|
||||
}, 100*time.Millisecond)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
testTimeout("300")
|
||||
testTimeout("500")
|
||||
testSucces("50")
|
||||
testSucces("30")
|
||||
}
|
||||
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, "Request Timeout", string(body))
|
||||
// }
|
||||
// testSucces := func(timeoutStr string) {
|
||||
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
// utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
var ErrFooTimeOut = errors.New("foo context canceled")
|
||||
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body))
|
||||
// }
|
||||
// go test -run Test_TimeoutWithCustomError
|
||||
func Test_TimeoutWithCustomError(t *testing.T) {
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
h := New(func(c *fiber.Ctx) error {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
|
||||
return fmt.Errorf("%w: execution error", err)
|
||||
}
|
||||
return nil
|
||||
}, 100*time.Millisecond, ErrFooTimeOut)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
}
|
||||
testTimeout("300")
|
||||
testTimeout("500")
|
||||
testSucces("50")
|
||||
testSucces("30")
|
||||
}
|
||||
|
||||
// testTimeout("15")
|
||||
// testSucces("2")
|
||||
// testTimeout("30")
|
||||
// testSucces("3")
|
||||
// }
|
||||
|
||||
// // go test -run -v Test_Timeout_Panic
|
||||
// func Test_Timeout_Panic(t *testing.T) {
|
||||
// app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error {
|
||||
// c.Set("dummy", "this should not be here")
|
||||
// panic("panic in timeout handler")
|
||||
// }, 5*time.Millisecond))
|
||||
|
||||
// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
|
||||
// utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, "Request Timeout", string(body))
|
||||
// }
|
||||
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
|
||||
timer := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return te
|
||||
case <-timer.C:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue