Feat: Register custom methods (#2107)

* Implementing register custom methods

* Return timout time to 1000

* Update app_test.go

* Change update stack to add custom request method

* Feat: Register custom methods #2107

* Feat: Register custom methods #2107

* update logic

* optimization.

* fix

Co-authored-by: Rafi Muhammad <rafi.muhammad@mekari.com>
Co-authored-by: RW <rene@gofiber.io>
Co-authored-by: Muhammed Efe Çetin <efectn@protonmail.com>
pull/2204/head
Rafi Muhammad 2022-11-11 14:23:30 +07:00 committed by GitHub
parent 581af0052d
commit 878c9549d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 65 deletions

34
app.go
View File

@ -110,6 +110,8 @@ type App struct {
latestRoute *Route
// TLS handler
tlsHandler *TLSHandler
// custom method check
customMethod bool
// Mount fields
mountFields *mountFields
}
@ -380,6 +382,11 @@ type Config struct {
//
// Optional. Default: DefaultColors
ColorScheme Colors `json:"color_scheme"`
// RequestMethods provides customizibility for HTTP methods. You can add/remove methods as you wish.
//
// Optional. Defaukt: DefaultMethods
RequestMethods []string
}
// Static defines configuration options when defining static assets.
@ -445,6 +452,19 @@ const (
DefaultCompressedFileSuffix = ".fiber.gz"
)
// HTTP methods enabled by default
var DefaultMethods = []string{
MethodGet,
MethodHead,
MethodPost,
MethodPut,
MethodDelete,
MethodConnect,
MethodOptions,
MethodTrace,
MethodPatch,
}
// DefaultErrorHandler that process return errors from handlers
var DefaultErrorHandler = func(c *Ctx, err error) error {
code := StatusInternalServerError
@ -469,9 +489,6 @@ var DefaultErrorHandler = func(c *Ctx, err error) error {
func New(config ...Config) *App {
// Create a new app
app := &App{
// Create router stack
stack: make([][]*Route, len(intMethod)),
treeStack: make([]map[string][]*Route, len(intMethod)),
// Create Ctx pool
pool: sync.Pool{
New: func() interface{} {
@ -538,12 +555,21 @@ func New(config ...Config) *App {
if app.config.Network == "" {
app.config.Network = NetworkTCP4
}
if len(app.config.RequestMethods) == 0 {
app.config.RequestMethods = DefaultMethods
} else {
app.customMethod = true
}
app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
for _, ipAddress := range app.config.TrustedProxies {
app.handleTrustedProxy(ipAddress)
}
// Create router stack
app.stack = make([][]*Route, len(app.config.RequestMethods))
app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods))
// Override colors
app.config.ColorScheme = defaultColors(app.config.ColorScheme)
@ -724,7 +750,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router {
// All will register the handler on all HTTP methods
func (app *App) All(path string, handlers ...Handler) Router {
for _, method := range intMethod {
for _, method := range app.config.RequestMethods {
_ = app.Add(method, path, handlers...)
}
return app

View File

@ -435,13 +435,32 @@ func Test_App_Use_StrictRouting(t *testing.T) {
}
func Test_App_Add_Method_Test(t *testing.T) {
app := New()
defer func() {
if err := recover(); err != nil {
utils.AssertEqual(t, "add: invalid http method JOHN\n", fmt.Sprintf("%v", err))
utils.AssertEqual(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err))
}
}()
methods := append(DefaultMethods, "JOHN")
app := New(Config{
RequestMethods: methods,
})
app.Add("JOHN", "/doe", testEmptyHandler)
resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusMethodNotAllowed, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, StatusBadRequest, resp.StatusCode, "Status code")
app.Add("JANE", "/doe", testEmptyHandler)
}
// go test -run Test_App_GETOnly
@ -487,7 +506,7 @@ func Test_App_Chaining(t *testing.T) {
return c.SendStatus(202)
})
// check handler count for registered HEAD route
utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)")
utils.AssertEqual(t, 5, len(app.stack[app.methodInt(MethodHead)][0].Handlers), "app.Test(req)")
req := httptest.NewRequest(MethodPost, "/john", nil)
@ -1250,16 +1269,17 @@ func Test_App_Stack(t *testing.T) {
app.Post("/path3", testEmptyHandler)
stack := app.Stack()
utils.AssertEqual(t, 9, len(stack))
utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)]))
utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)]))
utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)]))
utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)]))
methodList := app.config.RequestMethods
utils.AssertEqual(t, len(methodList), len(stack))
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodGet)]))
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodHead)]))
utils.AssertEqual(t, 2, len(stack[app.methodInt(MethodPost)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPut)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPatch)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodDelete)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodConnect)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodOptions)]))
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodTrace)]))
}
// go test -run Test_App_HandlersCount
@ -1513,6 +1533,19 @@ func Test_App_SetTLSHandler(t *testing.T) {
utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName)
}
func Test_App_AddCustomRequestMethod(t *testing.T) {
methods := append(DefaultMethods, "TEST")
app := New(Config{
RequestMethods: methods,
})
appMethods := app.config.RequestMethods
// method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1
utils.AssertEqual(t, len(app.stack), len(appMethods))
utils.AssertEqual(t, len(app.stack), len(appMethods))
utils.AssertEqual(t, "TEST", appMethods[len(appMethods)-1])
}
func TestApp_GetRoutes(t *testing.T) {
app := New()
app.Use(func(c *Ctx) error {
@ -1524,7 +1557,7 @@ func TestApp_GetRoutes(t *testing.T) {
app.Delete("/delete", handler).Name("delete")
app.Post("/post", handler).Name("post")
routes := app.GetRoutes(false)
utils.AssertEqual(t, 11, len(routes))
utils.AssertEqual(t, 2+len(app.config.RequestMethods), len(routes))
methodMap := map[string]string{"/delete": "delete", "/post": "post"}
for _, route := range routes {
name, ok := methodMap[route.Path]
@ -1540,5 +1573,4 @@ func TestApp_GetRoutes(t *testing.T) {
utils.AssertEqual(t, true, ok)
utils.AssertEqual(t, name, route.Name)
}
}

4
ctx.go
View File

@ -163,7 +163,7 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
c.pathOriginal = app.getString(fctx.URI().PathOriginal())
// Set method
c.method = app.getString(fctx.Request.Header.Method())
c.methodINT = methodInt(c.method)
c.methodINT = app.methodInt(c.method)
// Attach *fasthttp.RequestCtx to ctx
c.fasthttp = fctx
// reset base uri
@ -906,7 +906,7 @@ func (c *Ctx) Location(path string) {
func (c *Ctx) Method(override ...string) string {
if len(override) > 0 {
method := utils.ToUpper(override[0])
mINT := methodInt(method)
mINT := c.app.methodInt(method)
if mINT == -1 {
return c.method
}

View File

@ -133,7 +133,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router {
// All will register the handler on all HTTP methods
func (grp *Group) All(path string, handlers ...Handler) Router {
for _, method := range intMethod {
for _, method := range grp.app.config.RequestMethods {
_ = grp.Add(method, path, handlers...)
}
return grp

View File

@ -78,8 +78,9 @@ func (app *App) quoteString(raw string) string {
}
// Scan stack if other methods match the request
func methodExist(ctx *Ctx) (exist bool) {
for i := 0; i < len(intMethod); i++ {
func (app *App) methodExist(ctx *Ctx) (exist bool) {
methods := app.config.RequestMethods
for i := 0; i < len(methods); i++ {
// Skip original method
if ctx.methodINT == i {
continue
@ -109,7 +110,7 @@ func methodExist(ctx *Ctx) (exist bool) {
// We matched
exist = true
// Add method to Allow header
ctx.Append(HeaderAllow, intMethod[i])
ctx.Append(HeaderAllow, methods[i])
// Break stack loop
break
}
@ -331,42 +332,41 @@ var getBytesImmutable = func(s string) (b []byte) {
}
// HTTP methods and their unique INTs
func methodInt(s string) int {
switch s {
case MethodGet:
return 0
case MethodHead:
return 1
case MethodPost:
return 2
case MethodPut:
return 3
case MethodDelete:
return 4
case MethodConnect:
return 5
case MethodOptions:
return 6
case MethodTrace:
return 7
case MethodPatch:
return 8
default:
return -1
func (app *App) methodInt(s string) int {
// For better performance
if !app.customMethod {
switch s {
case MethodGet:
return 0
case MethodHead:
return 1
case MethodPost:
return 2
case MethodPut:
return 3
case MethodDelete:
return 4
case MethodConnect:
return 5
case MethodOptions:
return 6
case MethodTrace:
return 7
case MethodPatch:
return 8
default:
return -1
}
}
}
// HTTP methods slice
var intMethod = []string{
MethodGet,
MethodHead,
MethodPost,
MethodPut,
MethodDelete,
MethodConnect,
MethodOptions,
MethodTrace,
MethodPatch,
// For method customization
for i, v := range app.config.RequestMethods {
if s == v {
return i
}
}
return -1
}
// HTTP methods were copied from net/http.

View File

@ -139,7 +139,7 @@ func (app *App) next(c *Ctx) (match bool, err error) {
// If no match, scan stack again if other methods match the request
// Moved from app.handler because middleware may break the route chain
if !c.matched && methodExist(c) {
if !c.matched && app.methodExist(c) {
err = ErrMethodNotAllowed
}
return
@ -216,7 +216,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
// Uppercase HTTP methods
method = utils.ToUpper(method)
// Check if the HTTP method is valid unless it's USE
if method != methodUse && methodInt(method) == -1 {
if method != methodUse && app.methodInt(method) == -1 {
panic(fmt.Sprintf("add: invalid http method %s\n", method))
}
// A route requires atleast one ctx handler
@ -277,7 +277,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
// Middleware route matches all HTTP methods
if isUse {
// Add route to all HTTP methods stack
for _, m := range intMethod {
for _, m := range app.config.RequestMethods {
// Create a route copy to avoid duplicates during compression
r := route
app.addRoute(m, &r)
@ -435,7 +435,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
}
// Get unique HTTP method identifier
m := methodInt(method)
m := app.methodInt(method)
// prevent identically route registration
l := len(app.stack[m])
@ -469,7 +469,7 @@ func (app *App) buildTree() *App {
}
// loop all the methods and stacks and create the prefix tree
for m := range intMethod {
for m := range app.config.RequestMethods {
tsMap := make(map[string][]*Route)
for _, route := range app.stack[m] {
treePath := ""
@ -483,7 +483,7 @@ func (app *App) buildTree() *App {
}
// loop the methods and tree stacks and add global stack and sort everything
for m := range intMethod {
for m := range app.config.RequestMethods {
tsMap := app.treeStack[m]
for treePart := range tsMap {
if treePart != "" {