mirror of https://github.com/gofiber/fiber.git
Feat: Register custom methods (#2107)
* Implementing register custom methods * Return timout time to 1000 * Update app_test.go * Change update stack to add custom request method * Feat: Register custom methods #2107 * Feat: Register custom methods #2107 * update logic * optimization. * fix Co-authored-by: Rafi Muhammad <rafi.muhammad@mekari.com> Co-authored-by: RW <rene@gofiber.io> Co-authored-by: Muhammed Efe Çetin <efectn@protonmail.com>pull/2204/head
parent
581af0052d
commit
878c9549d8
34
app.go
34
app.go
|
@ -110,6 +110,8 @@ type App struct {
|
||||||
latestRoute *Route
|
latestRoute *Route
|
||||||
// TLS handler
|
// TLS handler
|
||||||
tlsHandler *TLSHandler
|
tlsHandler *TLSHandler
|
||||||
|
// custom method check
|
||||||
|
customMethod bool
|
||||||
// Mount fields
|
// Mount fields
|
||||||
mountFields *mountFields
|
mountFields *mountFields
|
||||||
}
|
}
|
||||||
|
@ -380,6 +382,11 @@ type Config struct {
|
||||||
//
|
//
|
||||||
// Optional. Default: DefaultColors
|
// Optional. Default: DefaultColors
|
||||||
ColorScheme Colors `json:"color_scheme"`
|
ColorScheme Colors `json:"color_scheme"`
|
||||||
|
|
||||||
|
// RequestMethods provides customizibility for HTTP methods. You can add/remove methods as you wish.
|
||||||
|
//
|
||||||
|
// Optional. Defaukt: DefaultMethods
|
||||||
|
RequestMethods []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Static defines configuration options when defining static assets.
|
// Static defines configuration options when defining static assets.
|
||||||
|
@ -445,6 +452,19 @@ const (
|
||||||
DefaultCompressedFileSuffix = ".fiber.gz"
|
DefaultCompressedFileSuffix = ".fiber.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HTTP methods enabled by default
|
||||||
|
var DefaultMethods = []string{
|
||||||
|
MethodGet,
|
||||||
|
MethodHead,
|
||||||
|
MethodPost,
|
||||||
|
MethodPut,
|
||||||
|
MethodDelete,
|
||||||
|
MethodConnect,
|
||||||
|
MethodOptions,
|
||||||
|
MethodTrace,
|
||||||
|
MethodPatch,
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultErrorHandler that process return errors from handlers
|
// DefaultErrorHandler that process return errors from handlers
|
||||||
var DefaultErrorHandler = func(c *Ctx, err error) error {
|
var DefaultErrorHandler = func(c *Ctx, err error) error {
|
||||||
code := StatusInternalServerError
|
code := StatusInternalServerError
|
||||||
|
@ -469,9 +489,6 @@ var DefaultErrorHandler = func(c *Ctx, err error) error {
|
||||||
func New(config ...Config) *App {
|
func New(config ...Config) *App {
|
||||||
// Create a new app
|
// Create a new app
|
||||||
app := &App{
|
app := &App{
|
||||||
// Create router stack
|
|
||||||
stack: make([][]*Route, len(intMethod)),
|
|
||||||
treeStack: make([]map[string][]*Route, len(intMethod)),
|
|
||||||
// Create Ctx pool
|
// Create Ctx pool
|
||||||
pool: sync.Pool{
|
pool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
|
@ -538,12 +555,21 @@ func New(config ...Config) *App {
|
||||||
if app.config.Network == "" {
|
if app.config.Network == "" {
|
||||||
app.config.Network = NetworkTCP4
|
app.config.Network = NetworkTCP4
|
||||||
}
|
}
|
||||||
|
if len(app.config.RequestMethods) == 0 {
|
||||||
|
app.config.RequestMethods = DefaultMethods
|
||||||
|
} else {
|
||||||
|
app.customMethod = true
|
||||||
|
}
|
||||||
|
|
||||||
app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
|
app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
|
||||||
for _, ipAddress := range app.config.TrustedProxies {
|
for _, ipAddress := range app.config.TrustedProxies {
|
||||||
app.handleTrustedProxy(ipAddress)
|
app.handleTrustedProxy(ipAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create router stack
|
||||||
|
app.stack = make([][]*Route, len(app.config.RequestMethods))
|
||||||
|
app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods))
|
||||||
|
|
||||||
// Override colors
|
// Override colors
|
||||||
app.config.ColorScheme = defaultColors(app.config.ColorScheme)
|
app.config.ColorScheme = defaultColors(app.config.ColorScheme)
|
||||||
|
|
||||||
|
@ -724,7 +750,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router {
|
||||||
|
|
||||||
// All will register the handler on all HTTP methods
|
// All will register the handler on all HTTP methods
|
||||||
func (app *App) All(path string, handlers ...Handler) Router {
|
func (app *App) All(path string, handlers ...Handler) Router {
|
||||||
for _, method := range intMethod {
|
for _, method := range app.config.RequestMethods {
|
||||||
_ = app.Add(method, path, handlers...)
|
_ = app.Add(method, path, handlers...)
|
||||||
}
|
}
|
||||||
return app
|
return app
|
||||||
|
|
62
app_test.go
62
app_test.go
|
@ -435,13 +435,32 @@ func Test_App_Use_StrictRouting(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_App_Add_Method_Test(t *testing.T) {
|
func Test_App_Add_Method_Test(t *testing.T) {
|
||||||
app := New()
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
utils.AssertEqual(t, "add: invalid http method JOHN\n", fmt.Sprintf("%v", err))
|
utils.AssertEqual(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
methods := append(DefaultMethods, "JOHN")
|
||||||
|
app := New(Config{
|
||||||
|
RequestMethods: methods,
|
||||||
|
})
|
||||||
|
|
||||||
app.Add("JOHN", "/doe", testEmptyHandler)
|
app.Add("JOHN", "/doe", testEmptyHandler)
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
|
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
|
utils.AssertEqual(t, StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil))
|
||||||
|
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||||
|
utils.AssertEqual(t, StatusBadRequest, resp.StatusCode, "Status code")
|
||||||
|
|
||||||
|
app.Add("JANE", "/doe", testEmptyHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_App_GETOnly
|
// go test -run Test_App_GETOnly
|
||||||
|
@ -487,7 +506,7 @@ func Test_App_Chaining(t *testing.T) {
|
||||||
return c.SendStatus(202)
|
return c.SendStatus(202)
|
||||||
})
|
})
|
||||||
// check handler count for registered HEAD route
|
// check handler count for registered HEAD route
|
||||||
utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)")
|
utils.AssertEqual(t, 5, len(app.stack[app.methodInt(MethodHead)][0].Handlers), "app.Test(req)")
|
||||||
|
|
||||||
req := httptest.NewRequest(MethodPost, "/john", nil)
|
req := httptest.NewRequest(MethodPost, "/john", nil)
|
||||||
|
|
||||||
|
@ -1250,16 +1269,17 @@ func Test_App_Stack(t *testing.T) {
|
||||||
app.Post("/path3", testEmptyHandler)
|
app.Post("/path3", testEmptyHandler)
|
||||||
|
|
||||||
stack := app.Stack()
|
stack := app.Stack()
|
||||||
utils.AssertEqual(t, 9, len(stack))
|
methodList := app.config.RequestMethods
|
||||||
utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)]))
|
utils.AssertEqual(t, len(methodList), len(stack))
|
||||||
utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)]))
|
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodGet)]))
|
||||||
utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)]))
|
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodHead)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)]))
|
utils.AssertEqual(t, 2, len(stack[app.methodInt(MethodPost)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)]))
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPut)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)]))
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPatch)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)]))
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodDelete)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)]))
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodConnect)]))
|
||||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)]))
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodOptions)]))
|
||||||
|
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodTrace)]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -run Test_App_HandlersCount
|
// go test -run Test_App_HandlersCount
|
||||||
|
@ -1513,6 +1533,19 @@ func Test_App_SetTLSHandler(t *testing.T) {
|
||||||
utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName)
|
utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_App_AddCustomRequestMethod(t *testing.T) {
|
||||||
|
methods := append(DefaultMethods, "TEST")
|
||||||
|
app := New(Config{
|
||||||
|
RequestMethods: methods,
|
||||||
|
})
|
||||||
|
appMethods := app.config.RequestMethods
|
||||||
|
|
||||||
|
// method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1
|
||||||
|
utils.AssertEqual(t, len(app.stack), len(appMethods))
|
||||||
|
utils.AssertEqual(t, len(app.stack), len(appMethods))
|
||||||
|
utils.AssertEqual(t, "TEST", appMethods[len(appMethods)-1])
|
||||||
|
}
|
||||||
|
|
||||||
func TestApp_GetRoutes(t *testing.T) {
|
func TestApp_GetRoutes(t *testing.T) {
|
||||||
app := New()
|
app := New()
|
||||||
app.Use(func(c *Ctx) error {
|
app.Use(func(c *Ctx) error {
|
||||||
|
@ -1524,7 +1557,7 @@ func TestApp_GetRoutes(t *testing.T) {
|
||||||
app.Delete("/delete", handler).Name("delete")
|
app.Delete("/delete", handler).Name("delete")
|
||||||
app.Post("/post", handler).Name("post")
|
app.Post("/post", handler).Name("post")
|
||||||
routes := app.GetRoutes(false)
|
routes := app.GetRoutes(false)
|
||||||
utils.AssertEqual(t, 11, len(routes))
|
utils.AssertEqual(t, 2+len(app.config.RequestMethods), len(routes))
|
||||||
methodMap := map[string]string{"/delete": "delete", "/post": "post"}
|
methodMap := map[string]string{"/delete": "delete", "/post": "post"}
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
name, ok := methodMap[route.Path]
|
name, ok := methodMap[route.Path]
|
||||||
|
@ -1540,5 +1573,4 @@ func TestApp_GetRoutes(t *testing.T) {
|
||||||
utils.AssertEqual(t, true, ok)
|
utils.AssertEqual(t, true, ok)
|
||||||
utils.AssertEqual(t, name, route.Name)
|
utils.AssertEqual(t, name, route.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
4
ctx.go
4
ctx.go
|
@ -163,7 +163,7 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
|
||||||
c.pathOriginal = app.getString(fctx.URI().PathOriginal())
|
c.pathOriginal = app.getString(fctx.URI().PathOriginal())
|
||||||
// Set method
|
// Set method
|
||||||
c.method = app.getString(fctx.Request.Header.Method())
|
c.method = app.getString(fctx.Request.Header.Method())
|
||||||
c.methodINT = methodInt(c.method)
|
c.methodINT = app.methodInt(c.method)
|
||||||
// Attach *fasthttp.RequestCtx to ctx
|
// Attach *fasthttp.RequestCtx to ctx
|
||||||
c.fasthttp = fctx
|
c.fasthttp = fctx
|
||||||
// reset base uri
|
// reset base uri
|
||||||
|
@ -906,7 +906,7 @@ func (c *Ctx) Location(path string) {
|
||||||
func (c *Ctx) Method(override ...string) string {
|
func (c *Ctx) Method(override ...string) string {
|
||||||
if len(override) > 0 {
|
if len(override) > 0 {
|
||||||
method := utils.ToUpper(override[0])
|
method := utils.ToUpper(override[0])
|
||||||
mINT := methodInt(method)
|
mINT := c.app.methodInt(method)
|
||||||
if mINT == -1 {
|
if mINT == -1 {
|
||||||
return c.method
|
return c.method
|
||||||
}
|
}
|
||||||
|
|
2
group.go
2
group.go
|
@ -133,7 +133,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router {
|
||||||
|
|
||||||
// All will register the handler on all HTTP methods
|
// All will register the handler on all HTTP methods
|
||||||
func (grp *Group) All(path string, handlers ...Handler) Router {
|
func (grp *Group) All(path string, handlers ...Handler) Router {
|
||||||
for _, method := range intMethod {
|
for _, method := range grp.app.config.RequestMethods {
|
||||||
_ = grp.Add(method, path, handlers...)
|
_ = grp.Add(method, path, handlers...)
|
||||||
}
|
}
|
||||||
return grp
|
return grp
|
||||||
|
|
30
helpers.go
30
helpers.go
|
@ -78,8 +78,9 @@ func (app *App) quoteString(raw string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan stack if other methods match the request
|
// Scan stack if other methods match the request
|
||||||
func methodExist(ctx *Ctx) (exist bool) {
|
func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
||||||
for i := 0; i < len(intMethod); i++ {
|
methods := app.config.RequestMethods
|
||||||
|
for i := 0; i < len(methods); i++ {
|
||||||
// Skip original method
|
// Skip original method
|
||||||
if ctx.methodINT == i {
|
if ctx.methodINT == i {
|
||||||
continue
|
continue
|
||||||
|
@ -109,7 +110,7 @@ func methodExist(ctx *Ctx) (exist bool) {
|
||||||
// We matched
|
// We matched
|
||||||
exist = true
|
exist = true
|
||||||
// Add method to Allow header
|
// Add method to Allow header
|
||||||
ctx.Append(HeaderAllow, intMethod[i])
|
ctx.Append(HeaderAllow, methods[i])
|
||||||
// Break stack loop
|
// Break stack loop
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -331,7 +332,9 @@ var getBytesImmutable = func(s string) (b []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP methods and their unique INTs
|
// HTTP methods and their unique INTs
|
||||||
func methodInt(s string) int {
|
func (app *App) methodInt(s string) int {
|
||||||
|
// For better performance
|
||||||
|
if !app.customMethod {
|
||||||
switch s {
|
switch s {
|
||||||
case MethodGet:
|
case MethodGet:
|
||||||
return 0
|
return 0
|
||||||
|
@ -356,17 +359,14 @@ func methodInt(s string) int {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP methods slice
|
// For method customization
|
||||||
var intMethod = []string{
|
for i, v := range app.config.RequestMethods {
|
||||||
MethodGet,
|
if s == v {
|
||||||
MethodHead,
|
return i
|
||||||
MethodPost,
|
}
|
||||||
MethodPut,
|
}
|
||||||
MethodDelete,
|
|
||||||
MethodConnect,
|
return -1
|
||||||
MethodOptions,
|
|
||||||
MethodTrace,
|
|
||||||
MethodPatch,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP methods were copied from net/http.
|
// HTTP methods were copied from net/http.
|
||||||
|
|
12
router.go
12
router.go
|
@ -139,7 +139,7 @@ func (app *App) next(c *Ctx) (match bool, err error) {
|
||||||
|
|
||||||
// If no match, scan stack again if other methods match the request
|
// If no match, scan stack again if other methods match the request
|
||||||
// Moved from app.handler because middleware may break the route chain
|
// Moved from app.handler because middleware may break the route chain
|
||||||
if !c.matched && methodExist(c) {
|
if !c.matched && app.methodExist(c) {
|
||||||
err = ErrMethodNotAllowed
|
err = ErrMethodNotAllowed
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -216,7 +216,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
|
||||||
// Uppercase HTTP methods
|
// Uppercase HTTP methods
|
||||||
method = utils.ToUpper(method)
|
method = utils.ToUpper(method)
|
||||||
// Check if the HTTP method is valid unless it's USE
|
// Check if the HTTP method is valid unless it's USE
|
||||||
if method != methodUse && methodInt(method) == -1 {
|
if method != methodUse && app.methodInt(method) == -1 {
|
||||||
panic(fmt.Sprintf("add: invalid http method %s\n", method))
|
panic(fmt.Sprintf("add: invalid http method %s\n", method))
|
||||||
}
|
}
|
||||||
// A route requires atleast one ctx handler
|
// A route requires atleast one ctx handler
|
||||||
|
@ -277,7 +277,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
|
||||||
// Middleware route matches all HTTP methods
|
// Middleware route matches all HTTP methods
|
||||||
if isUse {
|
if isUse {
|
||||||
// Add route to all HTTP methods stack
|
// Add route to all HTTP methods stack
|
||||||
for _, m := range intMethod {
|
for _, m := range app.config.RequestMethods {
|
||||||
// Create a route copy to avoid duplicates during compression
|
// Create a route copy to avoid duplicates during compression
|
||||||
r := route
|
r := route
|
||||||
app.addRoute(m, &r)
|
app.addRoute(m, &r)
|
||||||
|
@ -435,7 +435,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get unique HTTP method identifier
|
// Get unique HTTP method identifier
|
||||||
m := methodInt(method)
|
m := app.methodInt(method)
|
||||||
|
|
||||||
// prevent identically route registration
|
// prevent identically route registration
|
||||||
l := len(app.stack[m])
|
l := len(app.stack[m])
|
||||||
|
@ -469,7 +469,7 @@ func (app *App) buildTree() *App {
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop all the methods and stacks and create the prefix tree
|
// loop all the methods and stacks and create the prefix tree
|
||||||
for m := range intMethod {
|
for m := range app.config.RequestMethods {
|
||||||
tsMap := make(map[string][]*Route)
|
tsMap := make(map[string][]*Route)
|
||||||
for _, route := range app.stack[m] {
|
for _, route := range app.stack[m] {
|
||||||
treePath := ""
|
treePath := ""
|
||||||
|
@ -483,7 +483,7 @@ func (app *App) buildTree() *App {
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop the methods and tree stacks and add global stack and sort everything
|
// loop the methods and tree stacks and add global stack and sort everything
|
||||||
for m := range intMethod {
|
for m := range app.config.RequestMethods {
|
||||||
tsMap := app.treeStack[m]
|
tsMap := app.treeStack[m]
|
||||||
for treePart := range tsMap {
|
for treePart := range tsMap {
|
||||||
if treePart != "" {
|
if treePart != "" {
|
||||||
|
|
Loading…
Reference in New Issue