mirror of
https://github.com/gofiber/fiber.git
synced 2025-05-31 11:52:41 +00:00
🐛 Fix routing with mount and static (#3454)
* fix routing with mount and static [Bug]: Static server in sub app does not work #3104 https://github.com/gofiber/fiber/issues/3104 [Bug]: When mounting a subapp with mount, the static route is inaccessible. #3442 https://github.com/gofiber/fiber/issues/3442
This commit is contained in:
parent
fccff19606
commit
94e30d7124
2
app.go
2
app.go
@ -93,8 +93,6 @@ type App struct {
|
||||
treeStack []map[string][]*Route
|
||||
// contains the information if the route stack has been changed to build the optimized tree
|
||||
routesRefreshed bool
|
||||
// Amount of registered routes
|
||||
routesCount uint32
|
||||
// Amount of registered handlers
|
||||
handlersCount uint32
|
||||
// Ctx pool
|
||||
|
8
mount.go
8
mount.go
@ -174,7 +174,6 @@ func (app *App) processSubAppsRoutes() {
|
||||
}
|
||||
}
|
||||
var handlersCount uint32
|
||||
var routePos uint32
|
||||
// Iterate over the stack of the parent app
|
||||
for m := range app.stack {
|
||||
// Iterate over each route in the stack
|
||||
@ -183,9 +182,6 @@ func (app *App) processSubAppsRoutes() {
|
||||
route := app.stack[m][i]
|
||||
// Check if the route has a mounted app
|
||||
if !route.mount {
|
||||
routePos++
|
||||
// If not, update the route's position and continue
|
||||
route.pos = routePos
|
||||
if !route.use || (route.use && m == 0) {
|
||||
handlersCount += uint32(len(route.Handlers))
|
||||
}
|
||||
@ -214,11 +210,7 @@ func (app *App) processSubAppsRoutes() {
|
||||
copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
|
||||
app.stack[m] = newStack
|
||||
|
||||
// Decrease the parent app's route count to account for the mounted app's original route
|
||||
atomic.AddUint32(&app.routesCount, ^uint32(0))
|
||||
i--
|
||||
// Increase the parent app's route count to account for the sub-app's routes
|
||||
atomic.AddUint32(&app.routesCount, uint32(len(subRoutes)))
|
||||
|
||||
// Mark the parent app's routes as refreshed
|
||||
app.routesRefreshed = true
|
||||
|
@ -89,7 +89,6 @@ func Test_App_Mount_Nested(t *testing.T) {
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
|
||||
utils.AssertEqual(t, uint32(6), app.handlersCount)
|
||||
utils.AssertEqual(t, uint32(6), app.routesCount)
|
||||
}
|
||||
|
||||
// go test -run Test_App_Mount_Express_Behavior
|
||||
@ -139,7 +138,6 @@ func Test_App_Mount_Express_Behavior(t *testing.T) {
|
||||
testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound)
|
||||
|
||||
utils.AssertEqual(t, uint32(17), app.handlersCount)
|
||||
utils.AssertEqual(t, uint32(16+9), app.routesCount)
|
||||
}
|
||||
|
||||
// go test -run Test_App_Mount_RoutePositions
|
||||
@ -195,19 +193,15 @@ func Test_App_Mount_RoutePositions(t *testing.T) {
|
||||
|
||||
utils.AssertEqual(t, true, routeStackGET[1].use)
|
||||
utils.AssertEqual(t, "/", routeStackGET[1].path)
|
||||
utils.AssertEqual(t, true, routeStackGET[0].pos < routeStackGET[1].pos, "wrong position of route 0")
|
||||
|
||||
utils.AssertEqual(t, false, routeStackGET[2].use)
|
||||
utils.AssertEqual(t, "/bar", routeStackGET[2].path)
|
||||
utils.AssertEqual(t, true, routeStackGET[1].pos < routeStackGET[2].pos, "wrong position of route 1")
|
||||
|
||||
utils.AssertEqual(t, true, routeStackGET[3].use)
|
||||
utils.AssertEqual(t, "/", routeStackGET[3].path)
|
||||
utils.AssertEqual(t, true, routeStackGET[2].pos < routeStackGET[3].pos, "wrong position of route 2")
|
||||
|
||||
utils.AssertEqual(t, false, routeStackGET[4].use)
|
||||
utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path)
|
||||
utils.AssertEqual(t, true, routeStackGET[3].pos < routeStackGET[4].pos, "wrong position of route 3")
|
||||
|
||||
utils.AssertEqual(t, 5, len(routeStackGET))
|
||||
}
|
||||
|
98
router.go
98
router.go
@ -7,7 +7,6 @@ package fiber
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@ -47,9 +46,8 @@ type Router interface {
|
||||
|
||||
// Route is a struct that holds all metadata for each registered handler.
|
||||
type Route struct {
|
||||
// ### important: always keep in sync with the copy method "app.copyRoute" ###
|
||||
// ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ###
|
||||
// Data for routing
|
||||
pos uint32 // Position in stack -> important for the sort of the matched routes
|
||||
use bool // USE matches path prefixes
|
||||
mount bool // Indicated a mounted app on a specific route
|
||||
star bool // Path equals '*'
|
||||
@ -215,9 +213,6 @@ func (*App) copyRoute(route *Route) *Route {
|
||||
path: route.path,
|
||||
routeParser: route.routeParser,
|
||||
|
||||
// misc
|
||||
pos: route.pos,
|
||||
|
||||
// Public data
|
||||
Path: route.Path,
|
||||
Params: route.Params,
|
||||
@ -298,11 +293,11 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
|
||||
for _, m := range app.config.RequestMethods {
|
||||
// Create a route copy to avoid duplicates during compression
|
||||
r := route
|
||||
app.addRoute(m, &r, isMount)
|
||||
app.addRoute(m, &r)
|
||||
}
|
||||
} else {
|
||||
// Add route to stack
|
||||
app.addRoute(method, &route, isMount)
|
||||
app.addRoute(method, &route)
|
||||
}
|
||||
}
|
||||
|
||||
@ -428,12 +423,20 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
|
||||
// Create route metadata without pointer
|
||||
route := Route{
|
||||
// Router booleans
|
||||
use: true,
|
||||
root: isRoot,
|
||||
use: true,
|
||||
mount: false,
|
||||
star: isStar,
|
||||
root: isRoot,
|
||||
|
||||
// Path data
|
||||
path: prefix,
|
||||
|
||||
// Group data
|
||||
group: nil,
|
||||
|
||||
// Public data
|
||||
Method: MethodGet,
|
||||
Path: prefix,
|
||||
Method: MethodGet,
|
||||
Handlers: []Handler{handler},
|
||||
}
|
||||
// Increment global handler count
|
||||
@ -444,13 +447,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) {
|
||||
app.addRoute(MethodHead, &route)
|
||||
}
|
||||
|
||||
func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||
// Check mounted routes
|
||||
var mounted bool
|
||||
if len(isMounted) > 0 {
|
||||
mounted = isMounted[0]
|
||||
}
|
||||
|
||||
func (app *App) addRoute(method string, route *Route) {
|
||||
// Get unique HTTP method identifier
|
||||
m := app.methodInt(method)
|
||||
|
||||
@ -460,8 +457,6 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||
preRoute := app.stack[m][l-1]
|
||||
preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
|
||||
} else {
|
||||
// Increment global route position
|
||||
route.pos = atomic.AddUint32(&app.routesCount, 1)
|
||||
route.Method = method
|
||||
// Add route to the stack
|
||||
app.stack[m] = append(app.stack[m], route)
|
||||
@ -469,7 +464,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||
}
|
||||
|
||||
// Execute onRoute hooks & change latestRoute if not adding mounted route
|
||||
if !mounted {
|
||||
if !route.mount {
|
||||
app.mutex.Lock()
|
||||
app.latestRoute = route
|
||||
if err := app.hooks.executeOnRouteHooks(*route); err != nil {
|
||||
@ -481,38 +476,59 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||
|
||||
// buildTree build the prefix tree from the previously registered routes
|
||||
func (app *App) buildTree() *App {
|
||||
// If routes haven't been refreshed, nothing to do
|
||||
if !app.routesRefreshed {
|
||||
return app
|
||||
}
|
||||
|
||||
// loop all the methods and stacks and create the prefix tree
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := make(map[string][]*Route)
|
||||
for _, route := range app.stack[m] {
|
||||
treePath := ""
|
||||
// 1) First loop: determine all possible 3-char prefixes ("treePaths") for each method
|
||||
for method := range app.config.RequestMethods {
|
||||
prefixSet := map[string]struct{}{
|
||||
"": {},
|
||||
}
|
||||
for _, route := range app.stack[method] {
|
||||
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
|
||||
treePath = route.routeParser.segs[0].Const[:3]
|
||||
prefix := route.routeParser.segs[0].Const[:3]
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
// create tree stack
|
||||
tsMap[treePath] = append(tsMap[treePath], route)
|
||||
}
|
||||
app.treeStack[m] = tsMap
|
||||
tsMap := make(map[string][]*Route, len(prefixSet))
|
||||
for prefix := range prefixSet {
|
||||
tsMap[prefix] = nil
|
||||
}
|
||||
app.treeStack[method] = tsMap
|
||||
}
|
||||
|
||||
// loop the methods and tree stacks and add global stack and sort everything
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := app.treeStack[m]
|
||||
for treePart := range tsMap {
|
||||
if treePart != "" {
|
||||
// merge global tree routes in current tree stack
|
||||
tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
|
||||
// 2) Second loop: for each method and each discovered treePath, assign matching routes
|
||||
for method := range app.config.RequestMethods {
|
||||
// get the map of buckets for this method
|
||||
tsMap := app.treeStack[method]
|
||||
|
||||
// for every treePath key (including the empty one)
|
||||
for treePath := range tsMap {
|
||||
// iterate all routes of this method
|
||||
for _, route := range app.stack[method] {
|
||||
// compute this route's own prefix ("" or first 3 chars)
|
||||
routePath := ""
|
||||
if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
|
||||
routePath = route.routeParser.segs[0].Const[:3]
|
||||
}
|
||||
|
||||
// if it's a global route, assign to every bucket
|
||||
if routePath == "" {
|
||||
tsMap[treePath] = append(tsMap[treePath], route)
|
||||
// otherwise only assign if this route's prefix matches the current bucket's key
|
||||
} else if routePath == treePath {
|
||||
tsMap[treePath] = append(tsMap[treePath], route)
|
||||
}
|
||||
}
|
||||
// sort tree slices with the positions
|
||||
slc := tsMap[treePart]
|
||||
sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
|
||||
|
||||
// after collecting, dedupe the bucket if it's not the global one
|
||||
tsMap[treePath] = uniqueRouteStack(tsMap[treePath])
|
||||
}
|
||||
}
|
||||
|
||||
// reset the flag and return
|
||||
app.routesRefreshed = false
|
||||
|
||||
return app
|
||||
}
|
||||
|
@ -20,7 +20,10 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var routesFixture routeJSON
|
||||
var (
|
||||
routesFixture routeJSON
|
||||
cssDir = "./.github/testdata/fs/css"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dat, err := os.ReadFile("./.github/testdata/testRoutes.json")
|
||||
@ -354,9 +357,8 @@ func Test_Router_Handler_Catch_Error(t *testing.T) {
|
||||
func Test_Route_Static_Root(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := "./.github/testdata/fs/css"
|
||||
app := New()
|
||||
app.Static("/", dir, Static{
|
||||
app.Static("/", cssDir, Static{
|
||||
Browse: true,
|
||||
})
|
||||
|
||||
@ -373,7 +375,7 @@ func Test_Route_Static_Root(t *testing.T) {
|
||||
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
|
||||
|
||||
app = New()
|
||||
app.Static("/", dir)
|
||||
app.Static("/", cssDir)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
@ -391,9 +393,8 @@ func Test_Route_Static_Root(t *testing.T) {
|
||||
func Test_Route_Static_HasPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := "./.github/testdata/fs/css"
|
||||
app := New()
|
||||
app.Static("/static", dir, Static{
|
||||
app.Static("/static", cssDir, Static{
|
||||
Browse: true,
|
||||
})
|
||||
|
||||
@ -414,7 +415,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
|
||||
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
|
||||
|
||||
app = New()
|
||||
app.Static("/static/", dir, Static{
|
||||
app.Static("/static/", cssDir, Static{
|
||||
Browse: true,
|
||||
})
|
||||
|
||||
@ -435,7 +436,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
|
||||
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
|
||||
|
||||
app = New()
|
||||
app.Static("/static", dir)
|
||||
app.Static("/static", cssDir)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
@ -454,7 +455,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
|
||||
utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color"))
|
||||
|
||||
app = New()
|
||||
app.Static("/static/", dir)
|
||||
app.Static("/static/", cssDir)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
@ -474,6 +475,8 @@ func Test_Route_Static_HasPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Router_NotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := New()
|
||||
app.Use(func(c *Ctx) error {
|
||||
return c.Next()
|
||||
@ -491,6 +494,8 @@ func Test_Router_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Router_NotFound_HTML_Inject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := New()
|
||||
app.Use(func(c *Ctx) error {
|
||||
return c.Next()
|
||||
@ -507,6 +512,31 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) {
|
||||
utils.AssertEqual(t, "Cannot DELETE /does/not/exist<script>alert('foo');</script>", string(c.Response.Body()))
|
||||
}
|
||||
|
||||
func Test_Router_Mount_n_Static(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := New()
|
||||
|
||||
app.Static("/static", cssDir, Static{Browse: true})
|
||||
app.Get("/", func(c *Ctx) error {
|
||||
return c.SendString("Home")
|
||||
})
|
||||
|
||||
subApp := New()
|
||||
app.Mount("/mount", subApp)
|
||||
subApp.Get("/test", func(c *Ctx) error {
|
||||
return c.SendString("Hello from /test")
|
||||
})
|
||||
|
||||
app.Use(func(c *Ctx) error {
|
||||
return c.Status(StatusNotFound).SendString("Not Found")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////
|
||||
///////////////// BENCHMARKS /////////////////
|
||||
//////////////////////////////////////////////
|
||||
|
Loading…
x
Reference in New Issue
Block a user