From 44d739496d5f5791a2afbed4c1fca6d83edfe8fc Mon Sep 17 00:00:00 2001 From: Fenny <25108519+Fenny@users.noreply.github.com> Date: Sat, 20 Jun 2020 17:26:48 +0200 Subject: [PATCH] Add 405 Method Not Allowed --- app.go | 4 ++-- app_test.go | 21 +++++++++++++++++++++ router.go | 23 ++++++++++++++++------- utils.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/app.go b/app.go index 1d191a84..df22109d 100644 --- a/app.go +++ b/app.go @@ -45,7 +45,7 @@ type Error struct { // App denotes the Fiber application. type App struct { mutex sync.Mutex - // Route stack + // Route stack divided by HTTP methods stack [][]*Route // Amount of registered routes routes int @@ -222,7 +222,7 @@ func New(settings ...*Settings) *App { // Create a new app app := &App{ // Create router stack - stack: make([][]*Route, len(methodINT)), + stack: make([][]*Route, len(methodINT)+1), // Create Ctx pool pool: sync.Pool{ New: func() interface{} { diff --git a/app_test.go b/app_test.go index 5e3a32d5..e9e4acaa 100644 --- a/app_test.go +++ b/app_test.go @@ -27,6 +27,27 @@ func testStatus200(t *testing.T, app *App, url string, method string) { utils.AssertEqual(t, 200, resp.StatusCode, "Status code") } +func Test_App_MethodNotAllowed(t *testing.T) { + app := New() + + app.Post("/", func(c *Ctx) {}) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 405, resp.StatusCode) + utils.AssertEqual(t, "POST", resp.Header.Get(HeaderAllow)) + + resp, err = app.Test(httptest.NewRequest("PATCH", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 405, resp.StatusCode) + utils.AssertEqual(t, "POST", resp.Header.Get(HeaderAllow)) + + resp, err = app.Test(httptest.NewRequest("PUT", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 405, resp.StatusCode) + utils.AssertEqual(t, "POST", resp.Header.Get(HeaderAllow)) +} + func Test_App_Routes(t *testing.T) { app := New() h := func(c *Ctx) {} diff --git a/router.go b/router.go index 1bcebb78..40acb9dc 100644 --- a/router.go +++ b/router.go @@ -37,13 +37,14 @@ type Router interface { // Route is a struct that holds all metadata for each registered handler type Route struct { // Data for routing - pos int // Position in stack - use bool // USE matches path prefixes - star bool // Path equals '*' - root bool // Path equals '/' - path string // Prettified path - routeParser routeParser // Parameter parser - routeParams []string // Case sensitive param keys + pos int // Position in stack + use bool // USE matches path prefixes + star bool // Path equals '*' + root bool // Path equals '/' + path string // Prettified path + allowedMethods string // Methods that are allowed on this route + routeParser routeParser // Parameter parser + routeParams []string // Case sensitive param keys // Public fields Path string // Original registered route path @@ -126,6 +127,10 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) { if match && app.Settings.ETag { setETag(ctx, false) } + // Scan stack for other methods + if !match { + setMethodNotAllowed(ctx) + } // Release Ctx app.ReleaseCtx(ctx) } @@ -313,9 +318,13 @@ func (app *App) registerStatic(prefix, root string, config ...Static) *Route { app.addRoute(MethodHead, route) return route } + func (app *App) addRoute(method string, route *Route) { // Get unique HTTP method indentifier m := methodINT[method] // Add route to the stack app.stack[m] = append(app.stack[m], route) + + // Add route to method allowed slice + app.stack[9] = append(app.stack[9], route) } diff --git a/utils.go b/utils.go index a2e70569..2338ad1c 100644 --- a/utils.go +++ b/utils.go @@ -15,6 +15,36 @@ import ( utils "github.com/gofiber/utils" ) +// Scan stack if other methods match +func setMethodNotAllowed(ctx *Ctx) { + original := getString(ctx.Fasthttp.Request.Header.Method()) + for m := range methodINT { + // Skip original method + if m == original { + continue + } + // Reset stack index + ctx.indexRoute = -1 + // Set new method + ctx.method = m + // Get stack length + lenr := len(ctx.app.stack[9]) - 1 + // Loop over the route stack starting from previous index + for ctx.indexRoute < lenr { + // Get *Route + route := ctx.app.stack[9][ctx.indexRoute] + // Check if it matches the request path + match, _ := route.match(ctx.path, ctx.pathOriginal) + // No match, next route + if match { + ctx.SendStatus(StatusMethodNotAllowed) + ctx.Vary(HeaderAllow, m) + break + } + } + } +} + // Generate and set ETag header to response func setETag(ctx *Ctx, weak bool) { // Don't generate ETags for invalid responses