mirror of https://github.com/gofiber/fiber.git
Merge remote-tracking branch 'upstream/master'
commit
4aafc504ad
5
app.go
5
app.go
|
@ -453,6 +453,11 @@ func (app *App) Listen(address interface{}, tlsconfig ...*tls.Config) error {
|
|||
return app.server.Serve(ln)
|
||||
}
|
||||
|
||||
// Handler returns the server handler
|
||||
func (app *App) Handler() fasthttp.RequestHandler {
|
||||
return app.handler
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server without interrupting any active connections.
|
||||
// Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle and then shut down.
|
||||
//
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
fiber "github.com/gofiber/fiber"
|
||||
fasthttp "github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var concurrencyCh = make(chan struct{}, fasthttp.DefaultConcurrency)
|
||||
|
||||
// Timeout wraps a handler and aborts the process of the handler if the timeout is reached
|
||||
func Timeout(handler fiber.Handler, timeout time.Duration) fiber.Handler {
|
||||
if timeout <= 0 {
|
||||
return handler
|
||||
}
|
||||
|
||||
// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418
|
||||
return func(ctx *fiber.Ctx) {
|
||||
select {
|
||||
case concurrencyCh <- struct{}{}:
|
||||
default:
|
||||
ctx.Next(fiber.ErrTooManyRequests)
|
||||
return
|
||||
}
|
||||
ch := make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
handler(ctx)
|
||||
ch <- struct{}{}
|
||||
<-concurrencyCh
|
||||
}()
|
||||
timeoutTimer := time.NewTimer(timeout)
|
||||
select {
|
||||
case <-ch:
|
||||
case <-timeoutTimer.C:
|
||||
ctx.Next(fiber.ErrRequestTimeout)
|
||||
}
|
||||
if !timeoutTimer.Stop() {
|
||||
// Collect possibly added time from the channel
|
||||
// if timer has been stopped and nobody collected its' value.
|
||||
select {
|
||||
case <-timeoutTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
# Timeout
|
||||
Wrapper function which provides a handler with a timeout.
|
||||
|
||||
If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the error handler.
|
||||
### Example
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
"github.com/gofiber/fiber"
|
||||
"github.com/gofiber/fiber/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
|
||||
// wrap the handler with a timeout
|
||||
app.Get("/foo", middleware.Timeout(
|
||||
func(ctx fiber.Ctx) {
|
||||
// do somthing
|
||||
},
|
||||
5 * time.Second,
|
||||
))
|
||||
|
||||
app.Listen(3000)
|
||||
}
|
||||
```
|
||||
|
||||
### Signatures
|
||||
```go
|
||||
func Timeout(handler fiber.Handler, timeout time.Duration) fiber.Handler {}
|
||||
```
|
|
@ -0,0 +1,50 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
fiber "github.com/gofiber/fiber"
|
||||
utils "github.com/gofiber/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_Middleware_Timeout
|
||||
func Test_Middleware_Timeout(t *testing.T) {
|
||||
app := fiber.New(&fiber.Settings{DisableStartupMessage: true})
|
||||
|
||||
h := Timeout(
|
||||
func(c *fiber.Ctx) {
|
||||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
|
||||
time.Sleep(sleepTime)
|
||||
c.SendString("After " + c.Params("sleepTime") + "ms sleeping")
|
||||
},
|
||||
5*time.Millisecond,
|
||||
)
|
||||
app.Get("/test/:sleepTime", h)
|
||||
|
||||
testTimeout := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Request Timeout", string(body))
|
||||
}
|
||||
testSucces := func(timeoutStr string) {
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body))
|
||||
}
|
||||
|
||||
testTimeout("15")
|
||||
testSucces("2")
|
||||
testTimeout("30")
|
||||
testSucces("3")
|
||||
}
|
Loading…
Reference in New Issue