diff --git a/app.go b/app.go index fd290230..25de1f14 100644 --- a/app.go +++ b/app.go @@ -92,6 +92,8 @@ type App struct { stack [][]*Route // Route stack divided by HTTP methods and route prefixes treeStack []map[string][]*Route + // contains the information if the route stack has been changed to build the optimized tree + routesRefreshed bool // Amount of registered routes routesCount int // Amount of registered handlers @@ -538,7 +540,8 @@ func (app *App) Listener(ln net.Listener) error { addr, tls := lnMetadata(ln) return app.prefork(addr, tls) } - + // prepare the server for the start + app.startupProcess() // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), false, "") @@ -561,6 +564,8 @@ func (app *App) Listen(addr string) error { if err != nil { return err } + // prepare the server for the start + app.startupProcess() // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), false, "") @@ -599,6 +604,8 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { if err != nil { return err } + // prepare the server for the start + app.startupProcess() // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), true, "") @@ -614,6 +621,8 @@ func (app *App) Config() Config { // Handler returns the server handler. func (app *App) Handler() fasthttp.RequestHandler { + // prepare the server for the start + app.startupProcess() return app.handler } @@ -669,6 +678,8 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response, if _, err = conn.r.Write(dump); err != nil { return nil, err } + // prepare the server for the start + app.startupProcess() // Serve conn to server channel := make(chan error) @@ -767,6 +778,15 @@ func (app *App) init() *App { return app } +// startupProcess Is the method which executes all the necessary processes just before the start of the server. +func (app *App) startupProcess() *App { + app.mutex.Lock() + app.buildTree() + app.mutex.Unlock() + return app +} + +// startupMessage prepares the startup message with the handler number, port, address and other information func (app *App) startupMessage(addr string, tls bool, pids string) { // ignore child processes if IsChild() { diff --git a/ctx_test.go b/ctx_test.go index ae41a060..b920c686 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -901,7 +901,7 @@ func Test_Ctx_InvalidMethod(t *testing.T) { fctx.Request.Header.SetMethod("InvalidMethod") fctx.Request.SetRequestURI("/") - app.handler(fctx) + app.Handler()(fctx) utils.AssertEqual(t, 400, fctx.Response.StatusCode()) utils.AssertEqual(t, []byte("Invalid http method"), fctx.Response.Body()) diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 12d90e26..2722662e 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -49,9 +49,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { func Test_CORS_Wildcard(t *testing.T) { // New fiber instance app := fiber.New() - // Get handler pointer - handler := app.Handler() - // OPTIONS (preflight) response headers when AllowOrigins is * app.Use(New(Config{ AllowOrigins: "*", @@ -60,6 +57,8 @@ func Test_CORS_Wildcard(t *testing.T) { ExposeHeaders: "X-Request-ID", AllowHeaders: "Authentication", })) + // Get handler pointer + handler := app.Handler() // Make request ctx := &fasthttp.RequestCtx{} @@ -90,12 +89,12 @@ func Test_CORS_Wildcard(t *testing.T) { func Test_CORS_Subdomain(t *testing.T) { // New fiber instance app := fiber.New() - // Get handler pointer - handler := app.Handler() - // OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain app.Use("/", New(Config{AllowOrigins: "http://*.example.com"})) + // Get handler pointer + handler := app.Handler() + // Make request with disallowed origin ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") diff --git a/prefork_test.go b/prefork_test.go index 24a25ee6..0a6723a9 100644 --- a/prefork_test.go +++ b/prefork_test.go @@ -79,7 +79,7 @@ func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) { os.Stdout = w - New().startupMessage(":3000", false, "") + New().startupProcess().startupMessage(":3000", false, "") utils.AssertEqual(t, nil, w.Close()) diff --git a/router.go b/router.go index a1f2ad1b..6ae723b4 100644 --- a/router.go +++ b/router.go @@ -417,13 +417,15 @@ func (app *App) addRoute(method string, route *Route) { route.Method = method // Add route to the stack app.stack[m] = append(app.stack[m], route) + app.routesRefreshed = true } - // Build router tree - app.buildTree() } // buildTree build the prefix tree from the previously registered routes func (app *App) buildTree() *App { + if app.routesRefreshed == false { + return app + } // loop all the methods and stacks and create the prefix tree for m := range intMethod { app.treeStack[m] = make(map[string][]*Route) @@ -449,6 +451,7 @@ func (app *App) buildTree() *App { }) } } + app.routesRefreshed = false return app } diff --git a/router_test.go b/router_test.go index 7396c1fa..01509ba9 100644 --- a/router_test.go +++ b/router_test.go @@ -257,7 +257,7 @@ func Test_Router_Handler_SetETag(t *testing.T) { c := &fasthttp.RequestCtx{} - app.handler(c) + app.Handler()(c) utils.AssertEqual(t, `"13-1831710635"`, string(c.Response.Header.Peek(HeaderETag))) } @@ -274,7 +274,7 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { c := &fasthttp.RequestCtx{} - app.handler(c) + app.Handler()(c) utils.AssertEqual(t, StatusInternalServerError, c.Response.Header.StatusCode()) } @@ -300,6 +300,7 @@ func Benchmark_App_MethodNotAllowed(b *testing.B) { } app.All("/this/is/a/", h) app.Get("/this/is/a/dummy/route/oke", h) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") @@ -307,7 +308,7 @@ func Benchmark_App_MethodNotAllowed(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } b.StopTimer() utils.AssertEqual(b, 405, c.Response.StatusCode()) @@ -322,6 +323,7 @@ func Benchmark_Router_NotFound(b *testing.B) { return c.Next() }) registerDummyRoutes(app) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") @@ -329,7 +331,7 @@ func Benchmark_Router_NotFound(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } utils.AssertEqual(b, 404, c.Response.StatusCode()) utils.AssertEqual(b, "Cannot DELETE /this/route/does/not/exist", string(c.Response.Body())) @@ -339,6 +341,7 @@ func Benchmark_Router_NotFound(b *testing.B) { func Benchmark_Router_Handler(b *testing.B) { app := New() registerDummyRoutes(app) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} @@ -348,7 +351,7 @@ func Benchmark_Router_Handler(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -358,6 +361,7 @@ func Benchmark_Router_Handler_Strict_Case(b *testing.B) { CaseSensitive: true, }) registerDummyRoutes(app) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} @@ -367,7 +371,7 @@ func Benchmark_Router_Handler_Strict_Case(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -379,13 +383,15 @@ func Benchmark_Router_Chain(b *testing.B) { } app.Get("/", handler, handler, handler, handler, handler, handler) + appHandler := app.Handler() + c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("GET") c.URI().SetPath("/") b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -402,13 +408,14 @@ func Benchmark_Router_WithCompression(b *testing.B) { app.Get("/", handler) app.Get("/", handler) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("GET") c.URI().SetPath("/") b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -416,6 +423,7 @@ func Benchmark_Router_WithCompression(b *testing.B) { func Benchmark_Router_Next(b *testing.B) { app := New() registerDummyRoutes(app) + app.startupProcess() request := &fasthttp.RequestCtx{} @@ -532,6 +540,7 @@ func Benchmark_Router_Handler_CaseSensitive(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} @@ -541,7 +550,7 @@ func Benchmark_Router_Handler_CaseSensitive(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -554,6 +563,8 @@ func Benchmark_Router_Handler_Unescape(b *testing.B) { return nil }) + appHandler := app.Handler() + c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodDelete) @@ -563,7 +574,7 @@ func Benchmark_Router_Handler_Unescape(b *testing.B) { for n := 0; n < b.N; n++ { c.URI().SetPath("/cr%C3%A9er") - app.handler(c) + appHandler(c) } } @@ -572,6 +583,7 @@ func Benchmark_Router_Handler_StrictRouting(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) + appHandler := app.Handler() c := &fasthttp.RequestCtx{} @@ -581,7 +593,7 @@ func Benchmark_Router_Handler_StrictRouting(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - app.handler(c) + appHandler(c) } } @@ -589,6 +601,7 @@ func Benchmark_Router_Handler_StrictRouting(b *testing.B) { func Benchmark_Router_Github_API(b *testing.B) { app := New() registerDummyRoutes(app) + app.startupProcess() c := &fasthttp.RequestCtx{} var match bool