🐛 Fix routing with mount and static (#3454)

* fix routing with mount and static

[Bug]: Static server in sub app does not work #3104
https://github.com/gofiber/fiber/issues/3104
[Bug]: When mounting a subapp with mount, the static route is inaccessible. #3442
https://github.com/gofiber/fiber/issues/3442
This commit is contained in:
RW 2025-05-16 08:29:39 +02:00 committed by GitHub
parent fccff19606
commit 94e30d7124
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 96 additions and 66 deletions

2
app.go
View File

@ -93,8 +93,6 @@ type App struct {
treeStack []map[string][]*Route treeStack []map[string][]*Route
// contains the information if the route stack has been changed to build the optimized tree // contains the information if the route stack has been changed to build the optimized tree
routesRefreshed bool routesRefreshed bool
// Amount of registered routes
routesCount uint32
// Amount of registered handlers // Amount of registered handlers
handlersCount uint32 handlersCount uint32
// Ctx pool // Ctx pool

View File

@ -174,7 +174,6 @@ func (app *App) processSubAppsRoutes() {
} }
} }
var handlersCount uint32 var handlersCount uint32
var routePos uint32
// Iterate over the stack of the parent app // Iterate over the stack of the parent app
for m := range app.stack { for m := range app.stack {
// Iterate over each route in the stack // Iterate over each route in the stack
@ -183,9 +182,6 @@ func (app *App) processSubAppsRoutes() {
route := app.stack[m][i] route := app.stack[m][i]
// Check if the route has a mounted app // Check if the route has a mounted app
if !route.mount { if !route.mount {
routePos++
// If not, update the route's position and continue
route.pos = routePos
if !route.use || (route.use && m == 0) { if !route.use || (route.use && m == 0) {
handlersCount += uint32(len(route.Handlers)) handlersCount += uint32(len(route.Handlers))
} }
@ -214,11 +210,7 @@ func (app *App) processSubAppsRoutes() {
copy(newStack[i+len(subRoutes):], app.stack[m][i+1:]) copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
app.stack[m] = newStack app.stack[m] = newStack
// Decrease the parent app's route count to account for the mounted app's original route
atomic.AddUint32(&app.routesCount, ^uint32(0))
i-- i--
// Increase the parent app's route count to account for the sub-app's routes
atomic.AddUint32(&app.routesCount, uint32(len(subRoutes)))
// Mark the parent app's routes as refreshed // Mark the parent app's routes as refreshed
app.routesRefreshed = true app.routesRefreshed = true

View File

@ -89,7 +89,6 @@ func Test_App_Mount_Nested(t *testing.T) {
utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, uint32(6), app.handlersCount) utils.AssertEqual(t, uint32(6), app.handlersCount)
utils.AssertEqual(t, uint32(6), app.routesCount)
} }
// go test -run Test_App_Mount_Express_Behavior // go test -run Test_App_Mount_Express_Behavior
@ -139,7 +138,6 @@ func Test_App_Mount_Express_Behavior(t *testing.T) {
testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound) testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound)
utils.AssertEqual(t, uint32(17), app.handlersCount) utils.AssertEqual(t, uint32(17), app.handlersCount)
utils.AssertEqual(t, uint32(16+9), app.routesCount)
} }
// go test -run Test_App_Mount_RoutePositions // go test -run Test_App_Mount_RoutePositions
@ -195,19 +193,15 @@ func Test_App_Mount_RoutePositions(t *testing.T) {
utils.AssertEqual(t, true, routeStackGET[1].use) utils.AssertEqual(t, true, routeStackGET[1].use)
utils.AssertEqual(t, "/", routeStackGET[1].path) utils.AssertEqual(t, "/", routeStackGET[1].path)
utils.AssertEqual(t, true, routeStackGET[0].pos < routeStackGET[1].pos, "wrong position of route 0")
utils.AssertEqual(t, false, routeStackGET[2].use) utils.AssertEqual(t, false, routeStackGET[2].use)
utils.AssertEqual(t, "/bar", routeStackGET[2].path) utils.AssertEqual(t, "/bar", routeStackGET[2].path)
utils.AssertEqual(t, true, routeStackGET[1].pos < routeStackGET[2].pos, "wrong position of route 1")
utils.AssertEqual(t, true, routeStackGET[3].use) utils.AssertEqual(t, true, routeStackGET[3].use)
utils.AssertEqual(t, "/", routeStackGET[3].path) utils.AssertEqual(t, "/", routeStackGET[3].path)
utils.AssertEqual(t, true, routeStackGET[2].pos < routeStackGET[3].pos, "wrong position of route 2")
utils.AssertEqual(t, false, routeStackGET[4].use) utils.AssertEqual(t, false, routeStackGET[4].use)
utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path) utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path)
utils.AssertEqual(t, true, routeStackGET[3].pos < routeStackGET[4].pos, "wrong position of route 3")
utils.AssertEqual(t, 5, len(routeStackGET)) utils.AssertEqual(t, 5, len(routeStackGET))
} }

View File

@ -7,7 +7,6 @@ package fiber
import ( import (
"fmt" "fmt"
"html" "html"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -47,9 +46,8 @@ type Router interface {
// Route is a struct that holds all metadata for each registered handler. // Route is a struct that holds all metadata for each registered handler.
type Route struct { type Route struct {
// ### important: always keep in sync with the copy method "app.copyRoute" ### // ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ###
// Data for routing // Data for routing
pos uint32 // Position in stack -> important for the sort of the matched routes
use bool // USE matches path prefixes use bool // USE matches path prefixes
mount bool // Indicated a mounted app on a specific route mount bool // Indicated a mounted app on a specific route
star bool // Path equals '*' star bool // Path equals '*'
@ -215,9 +213,6 @@ func (*App) copyRoute(route *Route) *Route {
path: route.path, path: route.path,
routeParser: route.routeParser, routeParser: route.routeParser,
// misc
pos: route.pos,
// Public data // Public data
Path: route.Path, Path: route.Path,
Params: route.Params, Params: route.Params,
@ -298,11 +293,11 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
for _, m := range app.config.RequestMethods { for _, m := range app.config.RequestMethods {
// Create a route copy to avoid duplicates during compression // Create a route copy to avoid duplicates during compression
r := route r := route
app.addRoute(m, &r, isMount) app.addRoute(m, &r)
} }
} else { } else {
// Add route to stack // Add route to stack
app.addRoute(method, &route, isMount) app.addRoute(method, &route)
} }
} }
@ -428,12 +423,20 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
// Create route metadata without pointer // Create route metadata without pointer
route := Route{ route := Route{
// Router booleans // Router booleans
use: true, use: true,
root: isRoot, mount: false,
star: isStar,
root: isRoot,
// Path data
path: prefix, path: prefix,
// Group data
group: nil,
// Public data // Public data
Method: MethodGet,
Path: prefix, Path: prefix,
Method: MethodGet,
Handlers: []Handler{handler}, Handlers: []Handler{handler},
} }
// Increment global handler count // Increment global handler count
@ -444,13 +447,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
app.addRoute(MethodHead, &route) app.addRoute(MethodHead, &route)
} }
func (app *App) addRoute(method string, route *Route, isMounted ...bool) { func (app *App) addRoute(method string, route *Route) {
// Check mounted routes
var mounted bool
if len(isMounted) > 0 {
mounted = isMounted[0]
}
// Get unique HTTP method identifier // Get unique HTTP method identifier
m := app.methodInt(method) m := app.methodInt(method)
@ -460,8 +457,6 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
preRoute := app.stack[m][l-1] preRoute := app.stack[m][l-1]
preRoute.Handlers = append(preRoute.Handlers, route.Handlers...) preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
} else { } else {
// Increment global route position
route.pos = atomic.AddUint32(&app.routesCount, 1)
route.Method = method route.Method = method
// Add route to the stack // Add route to the stack
app.stack[m] = append(app.stack[m], route) app.stack[m] = append(app.stack[m], route)
@ -469,7 +464,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
} }
// Execute onRoute hooks & change latestRoute if not adding mounted route // Execute onRoute hooks & change latestRoute if not adding mounted route
if !mounted { if !route.mount {
app.mutex.Lock() app.mutex.Lock()
app.latestRoute = route app.latestRoute = route
if err := app.hooks.executeOnRouteHooks(*route); err != nil { if err := app.hooks.executeOnRouteHooks(*route); err != nil {
@ -481,38 +476,59 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
// buildTree build the prefix tree from the previously registered routes // buildTree build the prefix tree from the previously registered routes
func (app *App) buildTree() *App { func (app *App) buildTree() *App {
// If routes haven't been refreshed, nothing to do
if !app.routesRefreshed { if !app.routesRefreshed {
return app return app
} }
// loop all the methods and stacks and create the prefix tree // 1) First loop: determine all possible 3-char prefixes ("treePaths") for each method
for m := range app.config.RequestMethods { for method := range app.config.RequestMethods {
tsMap := make(map[string][]*Route) prefixSet := map[string]struct{}{
for _, route := range app.stack[m] { "": {},
treePath := "" }
for _, route := range app.stack[method] {
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 { if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
treePath = route.routeParser.segs[0].Const[:3] prefix := route.routeParser.segs[0].Const[:3]
prefixSet[prefix] = struct{}{}
} }
// create tree stack
tsMap[treePath] = append(tsMap[treePath], route)
} }
app.treeStack[m] = tsMap tsMap := make(map[string][]*Route, len(prefixSet))
for prefix := range prefixSet {
tsMap[prefix] = nil
}
app.treeStack[method] = tsMap
} }
// loop the methods and tree stacks and add global stack and sort everything // 2) Second loop: for each method and each discovered treePath, assign matching routes
for m := range app.config.RequestMethods { for method := range app.config.RequestMethods {
tsMap := app.treeStack[m] // get the map of buckets for this method
for treePart := range tsMap { tsMap := app.treeStack[method]
if treePart != "" {
// merge global tree routes in current tree stack // for every treePath key (including the empty one)
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...)) for treePath := range tsMap {
// iterate all routes of this method
for _, route := range app.stack[method] {
// compute this route's own prefix ("" or first 3 chars)
routePath := ""
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
routePath = route.routeParser.segs[0].Const[:3]
}
// if it's a global route, assign to every bucket
if routePath == "" {
tsMap[treePath] = append(tsMap[treePath], route)
// otherwise only assign if this route's prefix matches the current bucket's key
} else if routePath == treePath {
tsMap[treePath] = append(tsMap[treePath], route)
}
} }
// sort tree slices with the positions
slc := tsMap[treePart] // after collecting, dedupe the bucket if it's not the global one
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos }) tsMap[treePath] = uniqueRouteStack(tsMap[treePath])
} }
} }
// reset the flag and return
app.routesRefreshed = false app.routesRefreshed = false
return app return app
} }

View File

@ -20,7 +20,10 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
var routesFixture routeJSON var (
routesFixture routeJSON
cssDir = "./.github/testdata/fs/css"
)
func init() { func init() {
dat, err := os.ReadFile("./.github/testdata/testRoutes.json") dat, err := os.ReadFile("./.github/testdata/testRoutes.json")
@ -354,9 +357,8 @@ func Test_Router_Handler_Catch_Error(t *testing.T) {
func Test_Route_Static_Root(t *testing.T) { func Test_Route_Static_Root(t *testing.T) {
t.Parallel() t.Parallel()
dir := "./.github/testdata/fs/css"
app := New() app := New()
app.Static("/", dir, Static{ app.Static("/", cssDir, Static{
Browse: true, Browse: true,
}) })
@ -373,7 +375,7 @@ func Test_Route_Static_Root(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New() app = New()
app.Static("/", dir) app.Static("/", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -391,9 +393,8 @@ func Test_Route_Static_Root(t *testing.T) {
func Test_Route_Static_HasPrefix(t *testing.T) { func Test_Route_Static_HasPrefix(t *testing.T) {
t.Parallel() t.Parallel()
dir := "./.github/testdata/fs/css"
app := New() app := New()
app.Static("/static", dir, Static{ app.Static("/static", cssDir, Static{
Browse: true, Browse: true,
}) })
@ -414,7 +415,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New() app = New()
app.Static("/static/", dir, Static{ app.Static("/static/", cssDir, Static{
Browse: true, Browse: true,
}) })
@ -435,7 +436,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New() app = New()
app.Static("/static", dir) app.Static("/static", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -454,7 +455,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New() app = New()
app.Static("/static/", dir) app.Static("/static/", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -474,6 +475,8 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
} }
func Test_Router_NotFound(t *testing.T) { func Test_Router_NotFound(t *testing.T) {
t.Parallel()
app := New() app := New()
app.Use(func(c *Ctx) error { app.Use(func(c *Ctx) error {
return c.Next() return c.Next()
@ -491,6 +494,8 @@ func Test_Router_NotFound(t *testing.T) {
} }
func Test_Router_NotFound_HTML_Inject(t *testing.T) { func Test_Router_NotFound_HTML_Inject(t *testing.T) {
t.Parallel()
app := New() app := New()
app.Use(func(c *Ctx) error { app.Use(func(c *Ctx) error {
return c.Next() return c.Next()
@ -507,6 +512,31 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) {
utils.AssertEqual(t, "Cannot DELETE /does/not/exist&lt;script&gt;alert(&#39;foo&#39;);&lt;/script&gt;", string(c.Response.Body())) utils.AssertEqual(t, "Cannot DELETE /does/not/exist&lt;script&gt;alert(&#39;foo&#39;);&lt;/script&gt;", string(c.Response.Body()))
} }
func Test_Router_Mount_n_Static(t *testing.T) {
t.Parallel()
app := New()
app.Static("/static", cssDir, Static{Browse: true})
app.Get("/", func(c *Ctx) error {
return c.SendString("Home")
})
subApp := New()
app.Mount("/mount", subApp)
subApp.Get("/test", func(c *Ctx) error {
return c.SendString("Hello from /test")
})
app.Use(func(c *Ctx) error {
return c.Status(StatusNotFound).SendString("Not Found")
})
resp, err := app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
}
////////////////////////////////////////////// //////////////////////////////////////////////
///////////////// BENCHMARKS ///////////////// ///////////////// BENCHMARKS /////////////////
////////////////////////////////////////////// //////////////////////////////////////////////