From 94e30d7124cbf75bc2055cc49ec1ca7b8d7690e3 Mon Sep 17 00:00:00 2001 From: RW Date: Fri, 16 May 2025 08:29:39 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20routing=20with=20mount=20a?= =?UTF-8?q?nd=20static=20(#3454)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix routing with mount and static [Bug]: Static server in sub app does not work #3104 https://github.com/gofiber/fiber/issues/3104 [Bug]: When mounting a subapp with mount, the static route is inaccessible. #3442 https://github.com/gofiber/fiber/issues/3442 --- app.go | 2 -- mount.go | 8 ----- mount_test.go | 6 ---- router.go | 98 +++++++++++++++++++++++++++++--------------------- router_test.go | 48 ++++++++++++++++++++----- 5 files changed, 96 insertions(+), 66 deletions(-) diff --git a/app.go b/app.go index 751df9f2..3c1cc3cc 100644 --- a/app.go +++ b/app.go @@ -93,8 +93,6 @@ type App struct { treeStack []map[string][]*Route // contains the information if the route stack has been changed to build the optimized tree routesRefreshed bool - // Amount of registered routes - routesCount uint32 // Amount of registered handlers handlersCount uint32 // Ctx pool diff --git a/mount.go b/mount.go index abb5695e..b26eecc6 100644 --- a/mount.go +++ b/mount.go @@ -174,7 +174,6 @@ func (app *App) processSubAppsRoutes() { } } var handlersCount uint32 - var routePos uint32 // Iterate over the stack of the parent app for m := range app.stack { // Iterate over each route in the stack @@ -183,9 +182,6 @@ func (app *App) processSubAppsRoutes() { route := app.stack[m][i] // Check if the route has a mounted app if !route.mount { - routePos++ - // If not, update the route's position and continue - route.pos = routePos if !route.use || (route.use && m == 0) { handlersCount += uint32(len(route.Handlers)) } @@ -214,11 +210,7 @@ func (app *App) processSubAppsRoutes() { copy(newStack[i+len(subRoutes):], app.stack[m][i+1:]) app.stack[m] = newStack - // Decrease the parent app's route count to account for the mounted app's original route - atomic.AddUint32(&app.routesCount, ^uint32(0)) i-- - // Increase the parent app's route count to account for the sub-app's routes - atomic.AddUint32(&app.routesCount, uint32(len(subRoutes))) // Mark the parent app's routes as refreshed app.routesRefreshed = true diff --git a/mount_test.go b/mount_test.go index c0ca6bf3..63d4f1b3 100644 --- a/mount_test.go +++ b/mount_test.go @@ -89,7 +89,6 @@ func Test_App_Mount_Nested(t *testing.T) { utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, uint32(6), app.handlersCount) - utils.AssertEqual(t, uint32(6), app.routesCount) } // go test -run Test_App_Mount_Express_Behavior @@ -139,7 +138,6 @@ func Test_App_Mount_Express_Behavior(t *testing.T) { testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound) utils.AssertEqual(t, uint32(17), app.handlersCount) - utils.AssertEqual(t, uint32(16+9), app.routesCount) } // go test -run Test_App_Mount_RoutePositions @@ -195,19 +193,15 @@ func Test_App_Mount_RoutePositions(t *testing.T) { utils.AssertEqual(t, true, routeStackGET[1].use) utils.AssertEqual(t, "/", routeStackGET[1].path) - utils.AssertEqual(t, true, routeStackGET[0].pos < routeStackGET[1].pos, "wrong position of route 0") utils.AssertEqual(t, false, routeStackGET[2].use) utils.AssertEqual(t, "/bar", routeStackGET[2].path) - utils.AssertEqual(t, true, routeStackGET[1].pos < routeStackGET[2].pos, "wrong position of route 1") utils.AssertEqual(t, true, routeStackGET[3].use) utils.AssertEqual(t, "/", routeStackGET[3].path) - utils.AssertEqual(t, true, routeStackGET[2].pos < routeStackGET[3].pos, "wrong position of route 2") utils.AssertEqual(t, false, routeStackGET[4].use) utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path) - utils.AssertEqual(t, true, routeStackGET[3].pos < routeStackGET[4].pos, "wrong position of route 3") utils.AssertEqual(t, 5, len(routeStackGET)) } diff --git a/router.go b/router.go index 4afa7415..61ddc1bf 100644 --- a/router.go +++ b/router.go @@ -7,7 +7,6 @@ package fiber import ( "fmt" "html" - "sort" "strconv" "strings" "sync/atomic" @@ -47,9 +46,8 @@ type Router interface { // Route is a struct that holds all metadata for each registered handler. type Route struct { - // ### important: always keep in sync with the copy method "app.copyRoute" ### + // ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ### // Data for routing - pos uint32 // Position in stack -> important for the sort of the matched routes use bool // USE matches path prefixes mount bool // Indicated a mounted app on a specific route star bool // Path equals '*' @@ -215,9 +213,6 @@ func (*App) copyRoute(route *Route) *Route { path: route.path, routeParser: route.routeParser, - // misc - pos: route.pos, - // Public data Path: route.Path, Params: route.Params, @@ -298,11 +293,11 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl for _, m := range app.config.RequestMethods { // Create a route copy to avoid duplicates during compression r := route - app.addRoute(m, &r, isMount) + app.addRoute(m, &r) } } else { // Add route to stack - app.addRoute(method, &route, isMount) + app.addRoute(method, &route) } } @@ -428,12 +423,20 @@ func (app *App) registerStatic(prefix, root string, config ...Static) { // Create route metadata without pointer route := Route{ // Router booleans - use: true, - root: isRoot, + use: true, + mount: false, + star: isStar, + root: isRoot, + + // Path data path: prefix, + + // Group data + group: nil, + // Public data - Method: MethodGet, Path: prefix, + Method: MethodGet, Handlers: []Handler{handler}, } // Increment global handler count @@ -444,13 +447,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) { app.addRoute(MethodHead, &route) } -func (app *App) addRoute(method string, route *Route, isMounted ...bool) { - // Check mounted routes - var mounted bool - if len(isMounted) > 0 { - mounted = isMounted[0] - } - +func (app *App) addRoute(method string, route *Route) { // Get unique HTTP method identifier m := app.methodInt(method) @@ -460,8 +457,6 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) { preRoute := app.stack[m][l-1] preRoute.Handlers = append(preRoute.Handlers, route.Handlers...) } else { - // Increment global route position - route.pos = atomic.AddUint32(&app.routesCount, 1) route.Method = method // Add route to the stack app.stack[m] = append(app.stack[m], route) @@ -469,7 +464,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) { } // Execute onRoute hooks & change latestRoute if not adding mounted route - if !mounted { + if !route.mount { app.mutex.Lock() app.latestRoute = route if err := app.hooks.executeOnRouteHooks(*route); err != nil { @@ -481,38 +476,59 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) { // buildTree build the prefix tree from the previously registered routes func (app *App) buildTree() *App { + // If routes haven't been refreshed, nothing to do if !app.routesRefreshed { return app } - // loop all the methods and stacks and create the prefix tree - for m := range app.config.RequestMethods { - tsMap := make(map[string][]*Route) - for _, route := range app.stack[m] { - treePath := "" + // 1) First loop: determine all possible 3-char prefixes ("treePaths") for each method + for method := range app.config.RequestMethods { + prefixSet := map[string]struct{}{ + "": {}, + } + for _, route := range app.stack[method] { if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 { - treePath = route.routeParser.segs[0].Const[:3] + prefix := route.routeParser.segs[0].Const[:3] + prefixSet[prefix] = struct{}{} } - // create tree stack - tsMap[treePath] = append(tsMap[treePath], route) } - app.treeStack[m] = tsMap + tsMap := make(map[string][]*Route, len(prefixSet)) + for prefix := range prefixSet { + tsMap[prefix] = nil + } + app.treeStack[method] = tsMap } - // loop the methods and tree stacks and add global stack and sort everything - for m := range app.config.RequestMethods { - tsMap := app.treeStack[m] - for treePart := range tsMap { - if treePart != "" { - // merge global tree routes in current tree stack - tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...)) + // 2) Second loop: for each method and each discovered treePath, assign matching routes + for method := range app.config.RequestMethods { + // get the map of buckets for this method + tsMap := app.treeStack[method] + + // for every treePath key (including the empty one) + for treePath := range tsMap { + // iterate all routes of this method + for _, route := range app.stack[method] { + // compute this route's own prefix ("" or first 3 chars) + routePath := "" + if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 { + routePath = route.routeParser.segs[0].Const[:3] + } + + // if it's a global route, assign to every bucket + if routePath == "" { + tsMap[treePath] = append(tsMap[treePath], route) + // otherwise only assign if this route's prefix matches the current bucket's key + } else if routePath == treePath { + tsMap[treePath] = append(tsMap[treePath], route) + } } - // sort tree slices with the positions - slc := tsMap[treePart] - sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos }) + + // after collecting, dedupe the bucket if it's not the global one + tsMap[treePath] = uniqueRouteStack(tsMap[treePath]) } } + + // reset the flag and return app.routesRefreshed = false - return app } diff --git a/router_test.go b/router_test.go index 8d1e40cb..2c4424f8 100644 --- a/router_test.go +++ b/router_test.go @@ -20,7 +20,10 @@ import ( "github.com/valyala/fasthttp" ) -var routesFixture routeJSON +var ( + routesFixture routeJSON + cssDir = "./.github/testdata/fs/css" +) func init() { dat, err := os.ReadFile("./.github/testdata/testRoutes.json") @@ -354,9 +357,8 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { func Test_Route_Static_Root(t *testing.T) { t.Parallel() - dir := "./.github/testdata/fs/css" app := New() - app.Static("/", dir, Static{ + app.Static("/", cssDir, Static{ Browse: true, }) @@ -373,7 +375,7 @@ func Test_Route_Static_Root(t *testing.T) { utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() - app.Static("/", dir) + app.Static("/", cssDir) resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -391,9 +393,8 @@ func Test_Route_Static_Root(t *testing.T) { func Test_Route_Static_HasPrefix(t *testing.T) { t.Parallel() - dir := "./.github/testdata/fs/css" app := New() - app.Static("/static", dir, Static{ + app.Static("/static", cssDir, Static{ Browse: true, }) @@ -414,7 +415,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() - app.Static("/static/", dir, Static{ + app.Static("/static/", cssDir, Static{ Browse: true, }) @@ -435,7 +436,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() - app.Static("/static", dir) + app.Static("/static", cssDir) resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -454,7 +455,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() - app.Static("/static/", dir) + app.Static("/static/", cssDir) resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -474,6 +475,8 @@ func Test_Route_Static_HasPrefix(t *testing.T) { } func Test_Router_NotFound(t *testing.T) { + t.Parallel() + app := New() app.Use(func(c *Ctx) error { return c.Next() @@ -491,6 +494,8 @@ func Test_Router_NotFound(t *testing.T) { } func Test_Router_NotFound_HTML_Inject(t *testing.T) { + t.Parallel() + app := New() app.Use(func(c *Ctx) error { return c.Next() @@ -507,6 +512,31 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) { utils.AssertEqual(t, "Cannot DELETE /does/not/exist<script>alert('foo');</script>", string(c.Response.Body())) } +func Test_Router_Mount_n_Static(t *testing.T) { + t.Parallel() + + app := New() + + app.Static("/static", cssDir, Static{Browse: true}) + app.Get("/", func(c *Ctx) error { + return c.SendString("Home") + }) + + subApp := New() + app.Mount("/mount", subApp) + subApp.Get("/test", func(c *Ctx) error { + return c.SendString("Hello from /test") + }) + + app.Use(func(c *Ctx) error { + return c.Status(StatusNotFound).SendString("Not Found") + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") +} + ////////////////////////////////////////////// ///////////////// BENCHMARKS ///////////////// //////////////////////////////////////////////