Support for sub fiber's error handlers (#1560)

- Mounted fiber and its sub apps error handlers are now saved a new
  errorHandlers map in App
- New public App.ErrorHandler method that wraps the logic for which
  error handler to user on any given context
- Error handler match logic based on request path <=> prefix accuracy
- Typo fixes
- Tests
pull/1564/head
Jose Garcia 2021-10-05 08:03:20 -04:00 committed by GitHub
parent 2c6ffb7972
commit 587f3ae9df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 155 additions and 28 deletions

102
app.go
View File

@ -109,6 +109,8 @@ type App struct {
getBytes func(s string) (b []byte)
// Converts byte slice to a string
getString func(b []byte) string
// mount prefix -> error handler
errorHandlers map[string]ErrorHandler
}
// Config is a struct holding the server settings.
@ -426,9 +428,10 @@ func New(config ...Config) *App {
},
},
// Create config
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
errorHandlers: make(map[string]ErrorHandler),
}
// Override config if provided
if len(config) > 0 {
@ -460,9 +463,11 @@ func New(config ...Config) *App {
if app.config.Immutable {
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
}
if app.config.ErrorHandler == nil {
app.config.ErrorHandler = DefaultErrorHandler
}
if app.config.JSONEncoder == nil {
app.config.JSONEncoder = json.Marshal
}
@ -487,7 +492,9 @@ func New(config ...Config) *App {
// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
// compose them as a single service using Mount. The fiber's error handler and
// any of the fiber's sub apps are added to the application's error handlers
// to be invoked on errors that happen within the prefix route.
func (app *App) Mount(prefix string, fiber *App) Router {
stack := fiber.Stack()
for m := range stack {
@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router {
}
}
// Save the fiber's error handler and its sub apps
prefix = strings.TrimRight(prefix, "/")
if fiber.config.ErrorHandler != nil {
app.errorHandlers[prefix] = fiber.config.ErrorHandler
}
for mountedPrefixes, errHandler := range fiber.errorHandlers {
app.errorHandlers[prefix+mountedPrefixes] = errHandler
}
atomic.AddUint32(&app.handlerCount, fiber.handlerCount)
return app
@ -822,7 +838,7 @@ func (app *App) init() *App {
// lock application
app.mutex.Lock()
// Only load templates if an view engine is specified
// Only load templates if a view engine is specified
if app.config.Views != nil {
if err := app.config.Views.Load(); err != nil {
fmt.Printf("views: %v\n", err)
@ -833,26 +849,7 @@ func (app *App) init() *App {
app.server = &fasthttp.Server{
Logger: &disableLogger{},
LogAllErrors: false,
ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}
if catch := app.config.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
app.ReleaseCtx(c)
},
ErrorHandler: app.serverErrorHandler,
}
// fasthttp server settings
@ -880,6 +877,60 @@ func (app *App) init() *App {
return app
}
// ErrorHandler is the application's method in charge of finding the
// appropiate handler for the given request. It searches any mounted
// sub fibers by their prefixes and if it finds a match, it uses that
// error handler. Otherwise it uses the configured error handler for
// the app, which if not set is the DefaultErrorHandler.
func (app *App) ErrorHandler(ctx *Ctx, err error) error {
var (
mountedErrHandler ErrorHandler
mountedPrefixParts int
)
for prefix, errHandler := range app.errorHandlers {
if strings.HasPrefix(ctx.path, prefix) {
parts := len(strings.Split(prefix, "/"))
if mountedPrefixParts <= parts {
mountedErrHandler = errHandler
mountedPrefixParts = parts
}
}
}
if mountedErrHandler != nil {
return mountedErrHandler(ctx, err)
}
return app.config.ErrorHandler(ctx, err)
}
// serverErrorHandler is a wrapper around the application's error handler method
// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber
// errors before calling the application's error handler method.
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}
if catch := app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
app.ReleaseCtx(c)
}
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
func (app *App) startupProcess() *App {
app.mutex.Lock()
@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
}
scheme := "http"
if tls {
scheme = "https"

View File

@ -1439,3 +1439,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
utils.AssertEqual(t, testString, string(body))
}
func Test_App_UseMountedErrorHandler(t *testing.T) {
app := New()
fiber := New(Config{
ErrorHandler: func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
},
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
app.Mount("/api", fiber)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body")
}
func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
app := New()
tsf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error")
}
tripleSubFiber := New(Config{
ErrorHandler: tsf,
})
tripleSubFiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
sf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub fiber error")
}
subfiber := New(Config{
ErrorHandler: sf,
})
subfiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
subfiber.Mount("/third", tripleSubFiber)
f := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
}
fiber := New(Config{
ErrorHandler: f,
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
fiber.Mount("/sub", subfiber)
app.Mount("/api", fiber)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil))
utils.AssertEqual(t, nil, err, "/api/sub req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body")
resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil))
utils.AssertEqual(t, nil, err, "/api/sub/third req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
b, err = ioutil.ReadAll(resp2.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body")
}

View File

@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler {
}
}
// override error handler
errHandler = c.App().Config().ErrorHandler
errHandler = c.App().ErrorHandler
})
var start, stop time.Time

View File

@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) {
// Find match in stack
match, err := app.next(c)
if err != nil {
if catch := c.app.config.ErrorHandler(c, err); catch != nil {
if catch := c.app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
}