mirror of https://github.com/gofiber/fiber.git
Support for sub fiber's error handlers (#1560)
- Mounted fiber and its sub apps error handlers are now saved a new errorHandlers map in App - New public App.ErrorHandler method that wraps the logic for which error handler to user on any given context - Error handler match logic based on request path <=> prefix accuracy - Typo fixes - Testspull/1564/head
parent
2c6ffb7972
commit
587f3ae9df
102
app.go
102
app.go
|
@ -109,6 +109,8 @@ type App struct {
|
|||
getBytes func(s string) (b []byte)
|
||||
// Converts byte slice to a string
|
||||
getString func(b []byte) string
|
||||
// mount prefix -> error handler
|
||||
errorHandlers map[string]ErrorHandler
|
||||
}
|
||||
|
||||
// Config is a struct holding the server settings.
|
||||
|
@ -426,9 +428,10 @@ func New(config ...Config) *App {
|
|||
},
|
||||
},
|
||||
// Create config
|
||||
config: Config{},
|
||||
getBytes: utils.UnsafeBytes,
|
||||
getString: utils.UnsafeString,
|
||||
config: Config{},
|
||||
getBytes: utils.UnsafeBytes,
|
||||
getString: utils.UnsafeString,
|
||||
errorHandlers: make(map[string]ErrorHandler),
|
||||
}
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
|
@ -460,9 +463,11 @@ func New(config ...Config) *App {
|
|||
if app.config.Immutable {
|
||||
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
|
||||
}
|
||||
|
||||
if app.config.ErrorHandler == nil {
|
||||
app.config.ErrorHandler = DefaultErrorHandler
|
||||
}
|
||||
|
||||
if app.config.JSONEncoder == nil {
|
||||
app.config.JSONEncoder = json.Marshal
|
||||
}
|
||||
|
@ -487,7 +492,9 @@ func New(config ...Config) *App {
|
|||
|
||||
// Mount attaches another app instance as a sub-router along a routing path.
|
||||
// It's very useful to split up a large API as many independent routers and
|
||||
// compose them as a single service using Mount.
|
||||
// compose them as a single service using Mount. The fiber's error handler and
|
||||
// any of the fiber's sub apps are added to the application's error handlers
|
||||
// to be invoked on errors that happen within the prefix route.
|
||||
func (app *App) Mount(prefix string, fiber *App) Router {
|
||||
stack := fiber.Stack()
|
||||
for m := range stack {
|
||||
|
@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router {
|
|||
}
|
||||
}
|
||||
|
||||
// Save the fiber's error handler and its sub apps
|
||||
prefix = strings.TrimRight(prefix, "/")
|
||||
if fiber.config.ErrorHandler != nil {
|
||||
app.errorHandlers[prefix] = fiber.config.ErrorHandler
|
||||
}
|
||||
for mountedPrefixes, errHandler := range fiber.errorHandlers {
|
||||
app.errorHandlers[prefix+mountedPrefixes] = errHandler
|
||||
}
|
||||
|
||||
atomic.AddUint32(&app.handlerCount, fiber.handlerCount)
|
||||
|
||||
return app
|
||||
|
@ -822,7 +838,7 @@ func (app *App) init() *App {
|
|||
// lock application
|
||||
app.mutex.Lock()
|
||||
|
||||
// Only load templates if an view engine is specified
|
||||
// Only load templates if a view engine is specified
|
||||
if app.config.Views != nil {
|
||||
if err := app.config.Views.Load(); err != nil {
|
||||
fmt.Printf("views: %v\n", err)
|
||||
|
@ -833,26 +849,7 @@ func (app *App) init() *App {
|
|||
app.server = &fasthttp.Server{
|
||||
Logger: &disableLogger{},
|
||||
LogAllErrors: false,
|
||||
ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) {
|
||||
c := app.AcquireCtx(fctx)
|
||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
||||
err = ErrRequestHeaderFieldsTooLarge
|
||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
||||
err = ErrRequestTimeout
|
||||
} else if err == fasthttp.ErrBodyTooLarge {
|
||||
err = ErrRequestEntityTooLarge
|
||||
} else if err == fasthttp.ErrGetOnly {
|
||||
err = ErrMethodNotAllowed
|
||||
} else if strings.Contains(err.Error(), "timeout") {
|
||||
err = ErrRequestTimeout
|
||||
} else {
|
||||
err = ErrBadRequest
|
||||
}
|
||||
if catch := app.config.ErrorHandler(c, err); catch != nil {
|
||||
_ = c.SendStatus(StatusInternalServerError)
|
||||
}
|
||||
app.ReleaseCtx(c)
|
||||
},
|
||||
ErrorHandler: app.serverErrorHandler,
|
||||
}
|
||||
|
||||
// fasthttp server settings
|
||||
|
@ -880,6 +877,60 @@ func (app *App) init() *App {
|
|||
return app
|
||||
}
|
||||
|
||||
// ErrorHandler is the application's method in charge of finding the
|
||||
// appropiate handler for the given request. It searches any mounted
|
||||
// sub fibers by their prefixes and if it finds a match, it uses that
|
||||
// error handler. Otherwise it uses the configured error handler for
|
||||
// the app, which if not set is the DefaultErrorHandler.
|
||||
func (app *App) ErrorHandler(ctx *Ctx, err error) error {
|
||||
var (
|
||||
mountedErrHandler ErrorHandler
|
||||
mountedPrefixParts int
|
||||
)
|
||||
|
||||
for prefix, errHandler := range app.errorHandlers {
|
||||
if strings.HasPrefix(ctx.path, prefix) {
|
||||
parts := len(strings.Split(prefix, "/"))
|
||||
if mountedPrefixParts <= parts {
|
||||
mountedErrHandler = errHandler
|
||||
mountedPrefixParts = parts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mountedErrHandler != nil {
|
||||
return mountedErrHandler(ctx, err)
|
||||
}
|
||||
|
||||
return app.config.ErrorHandler(ctx, err)
|
||||
}
|
||||
|
||||
// serverErrorHandler is a wrapper around the application's error handler method
|
||||
// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber
|
||||
// errors before calling the application's error handler method.
|
||||
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
|
||||
c := app.AcquireCtx(fctx)
|
||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
||||
err = ErrRequestHeaderFieldsTooLarge
|
||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
||||
err = ErrRequestTimeout
|
||||
} else if err == fasthttp.ErrBodyTooLarge {
|
||||
err = ErrRequestEntityTooLarge
|
||||
} else if err == fasthttp.ErrGetOnly {
|
||||
err = ErrMethodNotAllowed
|
||||
} else if strings.Contains(err.Error(), "timeout") {
|
||||
err = ErrRequestTimeout
|
||||
} else {
|
||||
err = ErrBadRequest
|
||||
}
|
||||
|
||||
if catch := app.ErrorHandler(c, err); catch != nil {
|
||||
_ = c.SendStatus(StatusInternalServerError)
|
||||
}
|
||||
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
|
||||
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
||||
func (app *App) startupProcess() *App {
|
||||
app.mutex.Lock()
|
||||
|
@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
scheme := "http"
|
||||
if tls {
|
||||
scheme = "https"
|
||||
|
|
77
app_test.go
77
app_test.go
|
@ -1439,3 +1439,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
|||
|
||||
utils.AssertEqual(t, testString, string(body))
|
||||
}
|
||||
|
||||
func Test_App_UseMountedErrorHandler(t *testing.T) {
|
||||
app := New()
|
||||
|
||||
fiber := New(Config{
|
||||
ErrorHandler: func(ctx *Ctx, err error) error {
|
||||
return ctx.Status(200).SendString("hi, i'm a custom error")
|
||||
},
|
||||
})
|
||||
fiber.Get("/", func(c *Ctx) error {
|
||||
return errors.New("something happened")
|
||||
})
|
||||
|
||||
app.Mount("/api", fiber)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||
utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body")
|
||||
}
|
||||
|
||||
func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
|
||||
app := New()
|
||||
|
||||
tsf := func(ctx *Ctx, err error) error {
|
||||
return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error")
|
||||
}
|
||||
tripleSubFiber := New(Config{
|
||||
ErrorHandler: tsf,
|
||||
})
|
||||
tripleSubFiber.Get("/", func(c *Ctx) error {
|
||||
return errors.New("something happened")
|
||||
})
|
||||
|
||||
sf := func(ctx *Ctx, err error) error {
|
||||
return ctx.Status(200).SendString("hi, i'm a custom sub fiber error")
|
||||
}
|
||||
subfiber := New(Config{
|
||||
ErrorHandler: sf,
|
||||
})
|
||||
subfiber.Get("/", func(c *Ctx) error {
|
||||
return errors.New("something happened")
|
||||
})
|
||||
subfiber.Mount("/third", tripleSubFiber)
|
||||
|
||||
f := func(ctx *Ctx, err error) error {
|
||||
return ctx.Status(200).SendString("hi, i'm a custom error")
|
||||
}
|
||||
fiber := New(Config{
|
||||
ErrorHandler: f,
|
||||
})
|
||||
fiber.Get("/", func(c *Ctx) error {
|
||||
return errors.New("something happened")
|
||||
})
|
||||
fiber.Mount("/sub", subfiber)
|
||||
|
||||
app.Mount("/api", fiber)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil))
|
||||
utils.AssertEqual(t, nil, err, "/api/sub req")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||
utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body")
|
||||
|
||||
resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil))
|
||||
utils.AssertEqual(t, nil, err, "/api/sub/third req")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
|
||||
b, err = ioutil.ReadAll(resp2.Body)
|
||||
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||
utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body")
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler {
|
|||
}
|
||||
}
|
||||
// override error handler
|
||||
errHandler = c.App().Config().ErrorHandler
|
||||
errHandler = c.App().ErrorHandler
|
||||
})
|
||||
|
||||
var start, stop time.Time
|
||||
|
|
|
@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) {
|
|||
// Find match in stack
|
||||
match, err := app.next(c)
|
||||
if err != nil {
|
||||
if catch := c.app.config.ErrorHandler(c, err); catch != nil {
|
||||
if catch := c.app.ErrorHandler(c, err); catch != nil {
|
||||
_ = c.SendStatus(StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue