🐛 Fix routing with mount and static (#3454)

* fix routing with mount and static

[Bug]: Static server in sub app does not work #3104
https://github.com/gofiber/fiber/issues/3104
[Bug]: When mounting a subapp with mount, the static route is inaccessible. #3442
https://github.com/gofiber/fiber/issues/3442
This commit is contained in:
RW 2025-05-16 08:29:39 +02:00 committed by GitHub
parent fccff19606
commit 94e30d7124
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 96 additions and 66 deletions

2
app.go
View File

@ -93,8 +93,6 @@ type App struct {
treeStack []map[string][]*Route
// contains the information if the route stack has been changed to build the optimized tree
routesRefreshed bool
// Amount of registered routes
routesCount uint32
// Amount of registered handlers
handlersCount uint32
// Ctx pool

View File

@ -174,7 +174,6 @@ func (app *App) processSubAppsRoutes() {
}
}
var handlersCount uint32
var routePos uint32
// Iterate over the stack of the parent app
for m := range app.stack {
// Iterate over each route in the stack
@ -183,9 +182,6 @@ func (app *App) processSubAppsRoutes() {
route := app.stack[m][i]
// Check if the route has a mounted app
if !route.mount {
routePos++
// If not, update the route's position and continue
route.pos = routePos
if !route.use || (route.use && m == 0) {
handlersCount += uint32(len(route.Handlers))
}
@ -214,11 +210,7 @@ func (app *App) processSubAppsRoutes() {
copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
app.stack[m] = newStack
// Decrease the parent app's route count to account for the mounted app's original route
atomic.AddUint32(&app.routesCount, ^uint32(0))
i--
// Increase the parent app's route count to account for the sub-app's routes
atomic.AddUint32(&app.routesCount, uint32(len(subRoutes)))
// Mark the parent app's routes as refreshed
app.routesRefreshed = true

View File

@ -89,7 +89,6 @@ func Test_App_Mount_Nested(t *testing.T) {
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, uint32(6), app.handlersCount)
utils.AssertEqual(t, uint32(6), app.routesCount)
}
// go test -run Test_App_Mount_Express_Behavior
@ -139,7 +138,6 @@ func Test_App_Mount_Express_Behavior(t *testing.T) {
testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound)
utils.AssertEqual(t, uint32(17), app.handlersCount)
utils.AssertEqual(t, uint32(16+9), app.routesCount)
}
// go test -run Test_App_Mount_RoutePositions
@ -195,19 +193,15 @@ func Test_App_Mount_RoutePositions(t *testing.T) {
utils.AssertEqual(t, true, routeStackGET[1].use)
utils.AssertEqual(t, "/", routeStackGET[1].path)
utils.AssertEqual(t, true, routeStackGET[0].pos < routeStackGET[1].pos, "wrong position of route 0")
utils.AssertEqual(t, false, routeStackGET[2].use)
utils.AssertEqual(t, "/bar", routeStackGET[2].path)
utils.AssertEqual(t, true, routeStackGET[1].pos < routeStackGET[2].pos, "wrong position of route 1")
utils.AssertEqual(t, true, routeStackGET[3].use)
utils.AssertEqual(t, "/", routeStackGET[3].path)
utils.AssertEqual(t, true, routeStackGET[2].pos < routeStackGET[3].pos, "wrong position of route 2")
utils.AssertEqual(t, false, routeStackGET[4].use)
utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path)
utils.AssertEqual(t, true, routeStackGET[3].pos < routeStackGET[4].pos, "wrong position of route 3")
utils.AssertEqual(t, 5, len(routeStackGET))
}

View File

@ -7,7 +7,6 @@ package fiber
import (
"fmt"
"html"
"sort"
"strconv"
"strings"
"sync/atomic"
@ -47,9 +46,8 @@ type Router interface {
// Route is a struct that holds all metadata for each registered handler.
type Route struct {
// ### important: always keep in sync with the copy method "app.copyRoute" ###
// ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ###
// Data for routing
pos uint32 // Position in stack -> important for the sort of the matched routes
use bool // USE matches path prefixes
mount bool // Indicated a mounted app on a specific route
star bool // Path equals '*'
@ -215,9 +213,6 @@ func (*App) copyRoute(route *Route) *Route {
path: route.path,
routeParser: route.routeParser,
// misc
pos: route.pos,
// Public data
Path: route.Path,
Params: route.Params,
@ -298,11 +293,11 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
for _, m := range app.config.RequestMethods {
// Create a route copy to avoid duplicates during compression
r := route
app.addRoute(m, &r, isMount)
app.addRoute(m, &r)
}
} else {
// Add route to stack
app.addRoute(method, &route, isMount)
app.addRoute(method, &route)
}
}
@ -428,12 +423,20 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
// Create route metadata without pointer
route := Route{
// Router booleans
use: true,
root: isRoot,
use: true,
mount: false,
star: isStar,
root: isRoot,
// Path data
path: prefix,
// Group data
group: nil,
// Public data
Method: MethodGet,
Path: prefix,
Method: MethodGet,
Handlers: []Handler{handler},
}
// Increment global handler count
@ -444,13 +447,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
app.addRoute(MethodHead, &route)
}
func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
// Check mounted routes
var mounted bool
if len(isMounted) > 0 {
mounted = isMounted[0]
}
func (app *App) addRoute(method string, route *Route) {
// Get unique HTTP method identifier
m := app.methodInt(method)
@ -460,8 +457,6 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
preRoute := app.stack[m][l-1]
preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
} else {
// Increment global route position
route.pos = atomic.AddUint32(&app.routesCount, 1)
route.Method = method
// Add route to the stack
app.stack[m] = append(app.stack[m], route)
@ -469,7 +464,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
}
// Execute onRoute hooks & change latestRoute if not adding mounted route
if !mounted {
if !route.mount {
app.mutex.Lock()
app.latestRoute = route
if err := app.hooks.executeOnRouteHooks(*route); err != nil {
@ -481,38 +476,59 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
// buildTree build the prefix tree from the previously registered routes
func (app *App) buildTree() *App {
// If routes haven't been refreshed, nothing to do
if !app.routesRefreshed {
return app
}
// loop all the methods and stacks and create the prefix tree
for m := range app.config.RequestMethods {
tsMap := make(map[string][]*Route)
for _, route := range app.stack[m] {
treePath := ""
// 1) First loop: determine all possible 3-char prefixes ("treePaths") for each method
for method := range app.config.RequestMethods {
prefixSet := map[string]struct{}{
"": {},
}
for _, route := range app.stack[method] {
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
treePath = route.routeParser.segs[0].Const[:3]
prefix := route.routeParser.segs[0].Const[:3]
prefixSet[prefix] = struct{}{}
}
// create tree stack
tsMap[treePath] = append(tsMap[treePath], route)
}
app.treeStack[m] = tsMap
tsMap := make(map[string][]*Route, len(prefixSet))
for prefix := range prefixSet {
tsMap[prefix] = nil
}
app.treeStack[method] = tsMap
}
// loop the methods and tree stacks and add global stack and sort everything
for m := range app.config.RequestMethods {
tsMap := app.treeStack[m]
for treePart := range tsMap {
if treePart != "" {
// merge global tree routes in current tree stack
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
// 2) Second loop: for each method and each discovered treePath, assign matching routes
for method := range app.config.RequestMethods {
// get the map of buckets for this method
tsMap := app.treeStack[method]
// for every treePath key (including the empty one)
for treePath := range tsMap {
// iterate all routes of this method
for _, route := range app.stack[method] {
// compute this route's own prefix ("" or first 3 chars)
routePath := ""
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
routePath = route.routeParser.segs[0].Const[:3]
}
// if it's a global route, assign to every bucket
if routePath == "" {
tsMap[treePath] = append(tsMap[treePath], route)
// otherwise only assign if this route's prefix matches the current bucket's key
} else if routePath == treePath {
tsMap[treePath] = append(tsMap[treePath], route)
}
}
// sort tree slices with the positions
slc := tsMap[treePart]
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
// after collecting, dedupe the bucket if it's not the global one
tsMap[treePath] = uniqueRouteStack(tsMap[treePath])
}
}
// reset the flag and return
app.routesRefreshed = false
return app
}

View File

@ -20,7 +20,10 @@ import (
"github.com/valyala/fasthttp"
)
var routesFixture routeJSON
var (
routesFixture routeJSON
cssDir = "./.github/testdata/fs/css"
)
func init() {
dat, err := os.ReadFile("./.github/testdata/testRoutes.json")
@ -354,9 +357,8 @@ func Test_Router_Handler_Catch_Error(t *testing.T) {
func Test_Route_Static_Root(t *testing.T) {
t.Parallel()
dir := "./.github/testdata/fs/css"
app := New()
app.Static("/", dir, Static{
app.Static("/", cssDir, Static{
Browse: true,
})
@ -373,7 +375,7 @@ func Test_Route_Static_Root(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New()
app.Static("/", dir)
app.Static("/", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -391,9 +393,8 @@ func Test_Route_Static_Root(t *testing.T) {
func Test_Route_Static_HasPrefix(t *testing.T) {
t.Parallel()
dir := "./.github/testdata/fs/css"
app := New()
app.Static("/static", dir, Static{
app.Static("/static", cssDir, Static{
Browse: true,
})
@ -414,7 +415,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New()
app.Static("/static/", dir, Static{
app.Static("/static/", cssDir, Static{
Browse: true,
})
@ -435,7 +436,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New()
app.Static("/static", dir)
app.Static("/static", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -454,7 +455,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
app = New()
app.Static("/static/", dir)
app.Static("/static/", cssDir)
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
@ -474,6 +475,8 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
}
func Test_Router_NotFound(t *testing.T) {
t.Parallel()
app := New()
app.Use(func(c *Ctx) error {
return c.Next()
@ -491,6 +494,8 @@ func Test_Router_NotFound(t *testing.T) {
}
func Test_Router_NotFound_HTML_Inject(t *testing.T) {
t.Parallel()
app := New()
app.Use(func(c *Ctx) error {
return c.Next()
@ -507,6 +512,31 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) {
utils.AssertEqual(t, "Cannot DELETE /does/not/exist&lt;script&gt;alert(&#39;foo&#39;);&lt;/script&gt;", string(c.Response.Body()))
}
func Test_Router_Mount_n_Static(t *testing.T) {
t.Parallel()
app := New()
app.Static("/static", cssDir, Static{Browse: true})
app.Get("/", func(c *Ctx) error {
return c.SendString("Home")
})
subApp := New()
app.Mount("/mount", subApp)
subApp.Get("/test", func(c *Ctx) error {
return c.SendString("Hello from /test")
})
app.Use(func(c *Ctx) error {
return c.Status(StatusNotFound).SendString("Not Found")
})
resp, err := app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
}
//////////////////////////////////////////////
///////////////// BENCHMARKS /////////////////
//////////////////////////////////////////////