mirror of https://github.com/gofiber/fiber.git
Support for sub fiber's error handlers (#1560)
- Mounted fiber and its sub apps error handlers are now saved a new errorHandlers map in App - New public App.ErrorHandler method that wraps the logic for which error handler to user on any given context - Error handler match logic based on request path <=> prefix accuracy - Typo fixes - Testspull/1564/head
parent
2c6ffb7972
commit
587f3ae9df
102
app.go
102
app.go
|
@ -109,6 +109,8 @@ type App struct {
|
||||||
getBytes func(s string) (b []byte)
|
getBytes func(s string) (b []byte)
|
||||||
// Converts byte slice to a string
|
// Converts byte slice to a string
|
||||||
getString func(b []byte) string
|
getString func(b []byte) string
|
||||||
|
// mount prefix -> error handler
|
||||||
|
errorHandlers map[string]ErrorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config is a struct holding the server settings.
|
// Config is a struct holding the server settings.
|
||||||
|
@ -426,9 +428,10 @@ func New(config ...Config) *App {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// Create config
|
// Create config
|
||||||
config: Config{},
|
config: Config{},
|
||||||
getBytes: utils.UnsafeBytes,
|
getBytes: utils.UnsafeBytes,
|
||||||
getString: utils.UnsafeString,
|
getString: utils.UnsafeString,
|
||||||
|
errorHandlers: make(map[string]ErrorHandler),
|
||||||
}
|
}
|
||||||
// Override config if provided
|
// Override config if provided
|
||||||
if len(config) > 0 {
|
if len(config) > 0 {
|
||||||
|
@ -460,9 +463,11 @@ func New(config ...Config) *App {
|
||||||
if app.config.Immutable {
|
if app.config.Immutable {
|
||||||
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
|
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.config.ErrorHandler == nil {
|
if app.config.ErrorHandler == nil {
|
||||||
app.config.ErrorHandler = DefaultErrorHandler
|
app.config.ErrorHandler = DefaultErrorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.config.JSONEncoder == nil {
|
if app.config.JSONEncoder == nil {
|
||||||
app.config.JSONEncoder = json.Marshal
|
app.config.JSONEncoder = json.Marshal
|
||||||
}
|
}
|
||||||
|
@ -487,7 +492,9 @@ func New(config ...Config) *App {
|
||||||
|
|
||||||
// Mount attaches another app instance as a sub-router along a routing path.
|
// Mount attaches another app instance as a sub-router along a routing path.
|
||||||
// It's very useful to split up a large API as many independent routers and
|
// It's very useful to split up a large API as many independent routers and
|
||||||
// compose them as a single service using Mount.
|
// compose them as a single service using Mount. The fiber's error handler and
|
||||||
|
// any of the fiber's sub apps are added to the application's error handlers
|
||||||
|
// to be invoked on errors that happen within the prefix route.
|
||||||
func (app *App) Mount(prefix string, fiber *App) Router {
|
func (app *App) Mount(prefix string, fiber *App) Router {
|
||||||
stack := fiber.Stack()
|
stack := fiber.Stack()
|
||||||
for m := range stack {
|
for m := range stack {
|
||||||
|
@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save the fiber's error handler and its sub apps
|
||||||
|
prefix = strings.TrimRight(prefix, "/")
|
||||||
|
if fiber.config.ErrorHandler != nil {
|
||||||
|
app.errorHandlers[prefix] = fiber.config.ErrorHandler
|
||||||
|
}
|
||||||
|
for mountedPrefixes, errHandler := range fiber.errorHandlers {
|
||||||
|
app.errorHandlers[prefix+mountedPrefixes] = errHandler
|
||||||
|
}
|
||||||
|
|
||||||
atomic.AddUint32(&app.handlerCount, fiber.handlerCount)
|
atomic.AddUint32(&app.handlerCount, fiber.handlerCount)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
@ -822,7 +838,7 @@ func (app *App) init() *App {
|
||||||
// lock application
|
// lock application
|
||||||
app.mutex.Lock()
|
app.mutex.Lock()
|
||||||
|
|
||||||
// Only load templates if an view engine is specified
|
// Only load templates if a view engine is specified
|
||||||
if app.config.Views != nil {
|
if app.config.Views != nil {
|
||||||
if err := app.config.Views.Load(); err != nil {
|
if err := app.config.Views.Load(); err != nil {
|
||||||
fmt.Printf("views: %v\n", err)
|
fmt.Printf("views: %v\n", err)
|
||||||
|
@ -833,26 +849,7 @@ func (app *App) init() *App {
|
||||||
app.server = &fasthttp.Server{
|
app.server = &fasthttp.Server{
|
||||||
Logger: &disableLogger{},
|
Logger: &disableLogger{},
|
||||||
LogAllErrors: false,
|
LogAllErrors: false,
|
||||||
ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) {
|
ErrorHandler: app.serverErrorHandler,
|
||||||
c := app.AcquireCtx(fctx)
|
|
||||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
|
||||||
err = ErrRequestHeaderFieldsTooLarge
|
|
||||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
|
||||||
err = ErrRequestTimeout
|
|
||||||
} else if err == fasthttp.ErrBodyTooLarge {
|
|
||||||
err = ErrRequestEntityTooLarge
|
|
||||||
} else if err == fasthttp.ErrGetOnly {
|
|
||||||
err = ErrMethodNotAllowed
|
|
||||||
} else if strings.Contains(err.Error(), "timeout") {
|
|
||||||
err = ErrRequestTimeout
|
|
||||||
} else {
|
|
||||||
err = ErrBadRequest
|
|
||||||
}
|
|
||||||
if catch := app.config.ErrorHandler(c, err); catch != nil {
|
|
||||||
_ = c.SendStatus(StatusInternalServerError)
|
|
||||||
}
|
|
||||||
app.ReleaseCtx(c)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fasthttp server settings
|
// fasthttp server settings
|
||||||
|
@ -880,6 +877,60 @@ func (app *App) init() *App {
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorHandler is the application's method in charge of finding the
|
||||||
|
// appropiate handler for the given request. It searches any mounted
|
||||||
|
// sub fibers by their prefixes and if it finds a match, it uses that
|
||||||
|
// error handler. Otherwise it uses the configured error handler for
|
||||||
|
// the app, which if not set is the DefaultErrorHandler.
|
||||||
|
func (app *App) ErrorHandler(ctx *Ctx, err error) error {
|
||||||
|
var (
|
||||||
|
mountedErrHandler ErrorHandler
|
||||||
|
mountedPrefixParts int
|
||||||
|
)
|
||||||
|
|
||||||
|
for prefix, errHandler := range app.errorHandlers {
|
||||||
|
if strings.HasPrefix(ctx.path, prefix) {
|
||||||
|
parts := len(strings.Split(prefix, "/"))
|
||||||
|
if mountedPrefixParts <= parts {
|
||||||
|
mountedErrHandler = errHandler
|
||||||
|
mountedPrefixParts = parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mountedErrHandler != nil {
|
||||||
|
return mountedErrHandler(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return app.config.ErrorHandler(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverErrorHandler is a wrapper around the application's error handler method
|
||||||
|
// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber
|
||||||
|
// errors before calling the application's error handler method.
|
||||||
|
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
|
||||||
|
c := app.AcquireCtx(fctx)
|
||||||
|
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
||||||
|
err = ErrRequestHeaderFieldsTooLarge
|
||||||
|
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
||||||
|
err = ErrRequestTimeout
|
||||||
|
} else if err == fasthttp.ErrBodyTooLarge {
|
||||||
|
err = ErrRequestEntityTooLarge
|
||||||
|
} else if err == fasthttp.ErrGetOnly {
|
||||||
|
err = ErrMethodNotAllowed
|
||||||
|
} else if strings.Contains(err.Error(), "timeout") {
|
||||||
|
err = ErrRequestTimeout
|
||||||
|
} else {
|
||||||
|
err = ErrBadRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
if catch := app.ErrorHandler(c, err); catch != nil {
|
||||||
|
_ = c.SendStatus(StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.ReleaseCtx(c)
|
||||||
|
}
|
||||||
|
|
||||||
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
// startupProcess Is the method which executes all the necessary processes just before the start of the server.
|
||||||
func (app *App) startupProcess() *App {
|
func (app *App) startupProcess() *App {
|
||||||
app.mutex.Lock()
|
app.mutex.Lock()
|
||||||
|
@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if tls {
|
if tls {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
|
77
app_test.go
77
app_test.go
|
@ -1439,3 +1439,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {
|
||||||
|
|
||||||
utils.AssertEqual(t, testString, string(body))
|
utils.AssertEqual(t, testString, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_App_UseMountedErrorHandler(t *testing.T) {
|
||||||
|
app := New()
|
||||||
|
|
||||||
|
fiber := New(Config{
|
||||||
|
ErrorHandler: func(ctx *Ctx, err error) error {
|
||||||
|
return ctx.Status(200).SendString("hi, i'm a custom error")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fiber.Get("/", func(c *Ctx) error {
|
||||||
|
return errors.New("something happened")
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Mount("/api", fiber)
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||||
|
utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
|
||||||
|
app := New()
|
||||||
|
|
||||||
|
tsf := func(ctx *Ctx, err error) error {
|
||||||
|
return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error")
|
||||||
|
}
|
||||||
|
tripleSubFiber := New(Config{
|
||||||
|
ErrorHandler: tsf,
|
||||||
|
})
|
||||||
|
tripleSubFiber.Get("/", func(c *Ctx) error {
|
||||||
|
return errors.New("something happened")
|
||||||
|
})
|
||||||
|
|
||||||
|
sf := func(ctx *Ctx, err error) error {
|
||||||
|
return ctx.Status(200).SendString("hi, i'm a custom sub fiber error")
|
||||||
|
}
|
||||||
|
subfiber := New(Config{
|
||||||
|
ErrorHandler: sf,
|
||||||
|
})
|
||||||
|
subfiber.Get("/", func(c *Ctx) error {
|
||||||
|
return errors.New("something happened")
|
||||||
|
})
|
||||||
|
subfiber.Mount("/third", tripleSubFiber)
|
||||||
|
|
||||||
|
f := func(ctx *Ctx, err error) error {
|
||||||
|
return ctx.Status(200).SendString("hi, i'm a custom error")
|
||||||
|
}
|
||||||
|
fiber := New(Config{
|
||||||
|
ErrorHandler: f,
|
||||||
|
})
|
||||||
|
fiber.Get("/", func(c *Ctx) error {
|
||||||
|
return errors.New("something happened")
|
||||||
|
})
|
||||||
|
fiber.Mount("/sub", subfiber)
|
||||||
|
|
||||||
|
app.Mount("/api", fiber)
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "/api/sub req")
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||||
|
utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body")
|
||||||
|
|
||||||
|
resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "/api/sub/third req")
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
b, err = ioutil.ReadAll(resp2.Body)
|
||||||
|
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
|
||||||
|
utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body")
|
||||||
|
}
|
||||||
|
|
|
@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// override error handler
|
// override error handler
|
||||||
errHandler = c.App().Config().ErrorHandler
|
errHandler = c.App().ErrorHandler
|
||||||
})
|
})
|
||||||
|
|
||||||
var start, stop time.Time
|
var start, stop time.Time
|
||||||
|
|
|
@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) {
|
||||||
// Find match in stack
|
// Find match in stack
|
||||||
match, err := app.next(c)
|
match, err := app.next(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if catch := c.app.config.ErrorHandler(c, err); catch != nil {
|
if catch := c.app.ErrorHandler(c, err); catch != nil {
|
||||||
_ = c.SendStatus(StatusInternalServerError)
|
_ = c.SendStatus(StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue