diff --git a/app.go b/app.go index 67ad99db..f21f0452 100644 --- a/app.go +++ b/app.go @@ -52,11 +52,14 @@ type App struct { // Settings holds is a struct holding the server settings type Settings struct { - // Possible feature for v1.11.x // ErrorHandler is executed when you pass an error in the Next(err) method // This function is also executed when middleware.Recover() catches a panic - // Default: func(ctx *fiber.Ctx, err error) { - // ctx.Status(fiber.StatusBadRequest).SendString(err.Error()) + // Default: func(ctx *Ctx, err error) { + // code := StatusInternalServerError + // if e, ok := err.(*Error); ok { + // code = e.Code + // } + // ctx.Status(code).SendString(err.Error()) // } ErrorHandler func(*Ctx, error) @@ -156,6 +159,26 @@ type Static struct { Index string } +// Error represents an error that occurred while handling a request. +type Error struct { + Code int + Message string +} + +// Error makes it compatible with `error` interface. +func (e *Error) Error() string { + return e.Message +} + +// NewError creates a new HTTPError instance. +func NewError(code int, message ...string) *Error { + e := &Error{code, utils.StatusMessage(code)} + if len(message) > 0 { + e.Message = message[0] + } + return e +} + // Routes returns all registered routes // // for _, r := range app.Routes() { @@ -197,12 +220,16 @@ func New(settings ...*Settings) *App { Prefork: utils.GetArgument("-prefork"), BodyLimit: 4 * 1024 * 1024, Concurrency: 256 * 1024, - // Possible feature for v1.11.x ErrorHandler: func(ctx *Ctx, err error) { - ctx.Status(StatusInternalServerError).SendString(err.Error()) + code := StatusInternalServerError + if e, ok := err.(*Error); ok { + code = e.Code + } + ctx.Status(code).SendString(err.Error()) }, }, } + // Overwrite settings if provided if len(settings) > 0 { app.Settings = settings[0] @@ -220,10 +247,14 @@ func New(settings ...*Settings) *App { getBytes = getBytesImmutable getString = getStringImmutable } - // Possible feature for v1.11.x + // Set default error if app.Settings.ErrorHandler == nil { app.Settings.ErrorHandler = func(ctx *Ctx, err error) { - ctx.Status(StatusInternalServerError).SendString(err.Error()) + code := StatusInternalServerError + if e, ok := err.(*Error); ok { + code = e.Code + } + ctx.Status(code).SendString(err.Error()) } } } @@ -517,23 +548,18 @@ func (app *App) init() *App { Logger: &disableLogger{}, LogAllErrors: false, ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) { - // Possible feature for v1.11.x - // ctx := app.AcquireCtx(fctx) - // app.Settings.ErrorHandler(ctx, err) - // app.ReleaseCtx(ctx) + ctx := app.AcquireCtx(fctx) if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { - fctx.Response.SetStatusCode(StatusRequestHeaderFieldsTooLarge) - fctx.Response.SetBodyString("Request Header Fields Too Large") + ctx.err = ErrRequestHeaderFieldsTooLarge } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { - fctx.Response.SetStatusCode(StatusRequestTimeout) - fctx.Response.SetBodyString("Request Timeout") + ctx.err = ErrRequestTimeout } else if len(err.Error()) == 33 && err.Error() == "body size exceeds the given limit" { - fctx.Response.SetStatusCode(StatusRequestEntityTooLarge) - fctx.Response.SetBodyString("Request Entity Too Large") + ctx.err = ErrRequestEntityTooLarge } else { - fctx.Response.SetStatusCode(StatusBadRequest) - fctx.Response.SetBodyString("Bad Request") + ctx.err = ErrBadRequest } + app.Settings.ErrorHandler(ctx, ctx.err) // ctx.Route() not available + app.ReleaseCtx(ctx) }, } } diff --git a/ctx.go b/ctx.go index 29e08ec0..2f5abcf0 100644 --- a/ctx.go +++ b/ctx.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "encoding/xml" + "errors" "fmt" "io" "io/ioutil" @@ -40,7 +41,7 @@ type Ctx struct { path string // Prettified HTTP path pathOriginal string // Original HTTP path values []string // Route parameter values - err error // Contains error if caught + err error // Contains error if passed to Next Fasthttp *fasthttp.RequestCtx // Reference to *fasthttp.RequestCtx } @@ -70,9 +71,6 @@ type Templates interface { Render(io.Writer, string, interface{}) error } -// Global variables -var cacheControlNoCacheRegexp, _ = regexp.Compile(`/(?:^|,)\s*?no-cache\s*?(?:,|$)/`) - // AcquireCtx from pool func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx { ctx := app.pool.Get().(*Ctx) @@ -325,6 +323,9 @@ func (ctx *Ctx) Download(file string, filename ...string) { // Error contains the error information passed via the Next(err) method. func (ctx *Ctx) Error() error { + if ctx.err == nil { + return errors.New("") + } return ctx.err } @@ -384,6 +385,9 @@ func (ctx *Ctx) FormValue(key string) (value string) { return getString(ctx.Fasthttp.FormValue(key)) } +// Global variables +var cacheControlNoCacheRegexp, _ = regexp.Compile(`/(?:^|,)\s*?no-cache\s*?(?:,|$)/`) + // Fresh When the response is still “fresh” in the client’s cache true is returned, // otherwise false is returned to indicate that the client cache is now stale // and the full response should be sent. @@ -587,13 +591,10 @@ func (ctx *Ctx) MultipartForm() (*multipart.Form, error) { // Next executes the next method in the stack that matches the current route. // You can pass an optional error for custom error handling. func (ctx *Ctx) Next(err ...error) { - if ctx.app == nil { - return - } if len(err) > 0 { - ctx.err = err[0] ctx.Fasthttp.Response.Header.Reset() - ctx.app.Settings.ErrorHandler(ctx, err[0]) + ctx.err = err[0] + ctx.app.Settings.ErrorHandler(ctx, ctx.err) return } @@ -637,7 +638,7 @@ func (ctx *Ctx) Params(key string) string { // Path returns the path part of the request URL. // Optionally, you could override the path. func (ctx *Ctx) Path(override ...string) string { - if len(override) != 0 && ctx.path != override[0] && ctx.app != nil { + if len(override) != 0 && ctx.path != override[0] { // Set new path to request ctx.Fasthttp.Request.URI().SetPath(override[0]) // Set new path to context @@ -833,7 +834,8 @@ func (ctx *Ctx) SendFile(file string, compress ...bool) { hasTrailingSlash := len(file) > 0 && file[len(file)-1] == '/' var err error if file, err = filepath.Abs(file); err != nil { - ctx.app.Settings.ErrorHandler(ctx, err) + ctx.err = err + ctx.app.Settings.ErrorHandler(ctx, ctx.err) return } if hasTrailingSlash { diff --git a/ctx_test.go b/ctx_test.go index 67265f04..8a1597a7 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1024,6 +1024,20 @@ func Test_Ctx_Next(t *testing.T) { utils.AssertEqual(t, "Works", resp.Header.Get("X-Next-Result")) } +// go test -run Test_Ctx_Next_Error +func Test_Ctx_Next_Error(t *testing.T) { + app := New() + app.Use("/", func(c *Ctx) { + c.Set("X-Next-Result", "Works") + c.Next(ErrNotFound) + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/test", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, StatusNotFound, resp.StatusCode, "Status code") + utils.AssertEqual(t, "", resp.Header.Get("X-Next-Result")) +} + // go test -run Test_Ctx_Redirect func Test_Ctx_Redirect(t *testing.T) { t.Parallel() diff --git a/middleware/recover.go b/middleware/recover.go index ba686bef..2ea24b33 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -15,8 +15,7 @@ func Recover() fiber.Handler { if !ok { err = fmt.Errorf("%v", r) } - ctx.Fasthttp.Response.Header.Reset() - ctx.App().Settings.ErrorHandler(ctx, err) + ctx.Next(err) return } }() diff --git a/utils.go b/utils.go index 5cb10f40..55b59d65 100644 --- a/utils.go +++ b/utils.go @@ -451,6 +451,72 @@ const ( StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 ) +// Errors +var ( + ErrContinue = NewError(StatusContinue) // RFC 7231, 6.2.1 + ErrSwitchingProtocols = NewError(StatusSwitchingProtocols) // RFC 7231, 6.2.2 + ErrProcessing = NewError(StatusProcessing) // RFC 2518, 10.1 + ErrEarlyHints = NewError(StatusEarlyHints) // RFC 8297 + ErrOK = NewError(StatusOK) // RFC 7231, 6.3.1 + ErrCreated = NewError(StatusCreated) // RFC 7231, 6.3.2 + ErrAccepted = NewError(StatusAccepted) // RFC 7231, 6.3.3 + ErrNonAuthoritativeInfo = NewError(StatusNonAuthoritativeInfo) // RFC 7231, 6.3.4 + ErrNoContent = NewError(StatusNoContent) // RFC 7231, 6.3.5 + ErrResetContent = NewError(StatusResetContent) // RFC 7231, 6.3.6 + ErrPartialContent = NewError(StatusPartialContent) // RFC 7233, 4.1 + ErrMultiStatus = NewError(StatusMultiStatus) // RFC 4918, 11.1 + ErrAlreadyReported = NewError(StatusAlreadyReported) // RFC 5842, 7.1 + ErrIMUsed = NewError(StatusIMUsed) // RFC 3229, 10.4.1 + ErrMultipleChoices = NewError(StatusMultipleChoices) // RFC 7231, 6.4.1 + ErrMovedPermanently = NewError(StatusMovedPermanently) // RFC 7231, 6.4.2 + ErrFound = NewError(StatusFound) // RFC 7231, 6.4.3 + ErrSeeOther = NewError(StatusSeeOther) // RFC 7231, 6.4.4 + ErrNotModified = NewError(StatusNotModified) // RFC 7232, 4.1 + ErrUseProxy = NewError(StatusUseProxy) // RFC 7231, 6.4.5 + ErrTemporaryRedirect = NewError(StatusTemporaryRedirect) // RFC 7231, 6.4.7 + ErrPermanentRedirect = NewError(StatusPermanentRedirect) // RFC 7538, 3 + ErrBadRequest = NewError(StatusBadRequest) // RFC 7231, 6.5.1 + ErrUnauthorized = NewError(StatusUnauthorized) // RFC 7235, 3.1 + ErrPaymentRequired = NewError(StatusPaymentRequired) // RFC 7231, 6.5.2 + ErrForbidden = NewError(StatusForbidden) // RFC 7231, 6.5.3 + ErrNotFound = NewError(StatusNotFound) // RFC 7231, 6.5.4 + ErrMethodNotAllowed = NewError(StatusMethodNotAllowed) // RFC 7231, 6.5.5 + ErrNotAcceptable = NewError(StatusNotAcceptable) // RFC 7231, 6.5.6 + ErrProxyAuthRequired = NewError(StatusProxyAuthRequired) // RFC 7235, 3.2 + ErrRequestTimeout = NewError(StatusRequestTimeout) // RFC 7231, 6.5.7 + ErrConflict = NewError(StatusConflict) // RFC 7231, 6.5.8 + ErrGone = NewError(StatusGone) // RFC 7231, 6.5.9 + ErrLengthRequired = NewError(StatusLengthRequired) // RFC 7231, 6.5.10 + ErrPreconditionFailed = NewError(StatusPreconditionFailed) // RFC 7232, 4.2 + ErrRequestEntityTooLarge = NewError(StatusRequestEntityTooLarge) // RFC 7231, 6.5.11 + ErrRequestURITooLong = NewError(StatusRequestURITooLong) // RFC 7231, 6.5.12 + ErrUnsupportedMediaType = NewError(StatusUnsupportedMediaType) // RFC 7231, 6.5.13 + ErrRequestedRangeNotSatisfiable = NewError(StatusRequestedRangeNotSatisfiable) // RFC 7233, 4.4 + ErrExpectationFailed = NewError(StatusExpectationFailed) // RFC 7231, 6.5.14 + ErrTeapot = NewError(StatusTeapot) // RFC 7168, 2.3.3 + ErrMisdirectedRequest = NewError(StatusMisdirectedRequest) // RFC 7540, 9.1.2 + ErrUnprocessableEntity = NewError(StatusUnprocessableEntity) // RFC 4918, 11.2 + ErrLocked = NewError(StatusLocked) // RFC 4918, 11.3 + ErrFailedDependency = NewError(StatusFailedDependency) // RFC 4918, 11.4 + ErrTooEarly = NewError(StatusTooEarly) // RFC 8470, 5.2. + ErrUpgradeRequired = NewError(StatusUpgradeRequired) // RFC 7231, 6.5.15 + ErrPreconditionRequired = NewError(StatusPreconditionRequired) // RFC 6585, 3 + ErrTooManyRequests = NewError(StatusTooManyRequests) // RFC 6585, 4 + ErrRequestHeaderFieldsTooLarge = NewError(StatusRequestHeaderFieldsTooLarge) // RFC 6585, 5 + ErrUnavailableForLegalReasons = NewError(StatusUnavailableForLegalReasons) // RFC 7725, 3 + ErrInternalServerError = NewError(StatusInternalServerError) // RFC 7231, 6.6.1 + ErrNotImplemented = NewError(StatusNotImplemented) // RFC 7231, 6.6.2 + ErrBadGateway = NewError(StatusBadGateway) // RFC 7231, 6.6.3 + ErrServiceUnavailable = NewError(StatusServiceUnavailable) // RFC 7231, 6.6.4 + ErrGatewayTimeout = NewError(StatusGatewayTimeout) // RFC 7231, 6.6.5 + ErrHTTPVersionNotSupported = NewError(StatusHTTPVersionNotSupported) // RFC 7231, 6.6.6 + ErrVariantAlsoNegotiates = NewError(StatusVariantAlsoNegotiates) // RFC 2295, 8.1 + ErrInsufficientStorage = NewError(StatusInsufficientStorage) // RFC 4918, 11.5 + ErrLoopDetected = NewError(StatusLoopDetected) // RFC 5842, 7.2 + ErrNotExtended = NewError(StatusNotExtended) // RFC 2774, 7 + ErrNetworkAuthenticationRequired = NewError(StatusNetworkAuthenticationRequired) // RFC 6585, 6 +) + // HTTP Headers were copied from net/http. const ( HeaderAuthorization = "Authorization"