mirror of https://github.com/gofiber/fiber.git
Feat: Register custom methods (#2107)
* Implementing register custom methods * Return timout time to 1000 * Update app_test.go * Change update stack to add custom request method * Feat: Register custom methods #2107 * Feat: Register custom methods #2107 * update logic * optimization. * fix Co-authored-by: Rafi Muhammad <rafi.muhammad@mekari.com> Co-authored-by: RW <rene@gofiber.io> Co-authored-by: Muhammed Efe Çetin <efectn@protonmail.com>pull/2204/head
parent
581af0052d
commit
878c9549d8
34
app.go
34
app.go
|
@ -110,6 +110,8 @@ type App struct {
|
|||
latestRoute *Route
|
||||
// TLS handler
|
||||
tlsHandler *TLSHandler
|
||||
// custom method check
|
||||
customMethod bool
|
||||
// Mount fields
|
||||
mountFields *mountFields
|
||||
}
|
||||
|
@ -380,6 +382,11 @@ type Config struct {
|
|||
//
|
||||
// Optional. Default: DefaultColors
|
||||
ColorScheme Colors `json:"color_scheme"`
|
||||
|
||||
// RequestMethods provides customizibility for HTTP methods. You can add/remove methods as you wish.
|
||||
//
|
||||
// Optional. Defaukt: DefaultMethods
|
||||
RequestMethods []string
|
||||
}
|
||||
|
||||
// Static defines configuration options when defining static assets.
|
||||
|
@ -445,6 +452,19 @@ const (
|
|||
DefaultCompressedFileSuffix = ".fiber.gz"
|
||||
)
|
||||
|
||||
// HTTP methods enabled by default
|
||||
var DefaultMethods = []string{
|
||||
MethodGet,
|
||||
MethodHead,
|
||||
MethodPost,
|
||||
MethodPut,
|
||||
MethodDelete,
|
||||
MethodConnect,
|
||||
MethodOptions,
|
||||
MethodTrace,
|
||||
MethodPatch,
|
||||
}
|
||||
|
||||
// DefaultErrorHandler that process return errors from handlers
|
||||
var DefaultErrorHandler = func(c *Ctx, err error) error {
|
||||
code := StatusInternalServerError
|
||||
|
@ -469,9 +489,6 @@ var DefaultErrorHandler = func(c *Ctx, err error) error {
|
|||
func New(config ...Config) *App {
|
||||
// Create a new app
|
||||
app := &App{
|
||||
// Create router stack
|
||||
stack: make([][]*Route, len(intMethod)),
|
||||
treeStack: make([]map[string][]*Route, len(intMethod)),
|
||||
// Create Ctx pool
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
|
@ -538,12 +555,21 @@ func New(config ...Config) *App {
|
|||
if app.config.Network == "" {
|
||||
app.config.Network = NetworkTCP4
|
||||
}
|
||||
if len(app.config.RequestMethods) == 0 {
|
||||
app.config.RequestMethods = DefaultMethods
|
||||
} else {
|
||||
app.customMethod = true
|
||||
}
|
||||
|
||||
app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
|
||||
for _, ipAddress := range app.config.TrustedProxies {
|
||||
app.handleTrustedProxy(ipAddress)
|
||||
}
|
||||
|
||||
// Create router stack
|
||||
app.stack = make([][]*Route, len(app.config.RequestMethods))
|
||||
app.treeStack = make([]map[string][]*Route, len(app.config.RequestMethods))
|
||||
|
||||
// Override colors
|
||||
app.config.ColorScheme = defaultColors(app.config.ColorScheme)
|
||||
|
||||
|
@ -724,7 +750,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router {
|
|||
|
||||
// All will register the handler on all HTTP methods
|
||||
func (app *App) All(path string, handlers ...Handler) Router {
|
||||
for _, method := range intMethod {
|
||||
for _, method := range app.config.RequestMethods {
|
||||
_ = app.Add(method, path, handlers...)
|
||||
}
|
||||
return app
|
||||
|
|
62
app_test.go
62
app_test.go
|
@ -435,13 +435,32 @@ func Test_App_Use_StrictRouting(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_App_Add_Method_Test(t *testing.T) {
|
||||
app := New()
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
utils.AssertEqual(t, "add: invalid http method JOHN\n", fmt.Sprintf("%v", err))
|
||||
utils.AssertEqual(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
methods := append(DefaultMethods, "JOHN")
|
||||
app := New(Config{
|
||||
RequestMethods: methods,
|
||||
})
|
||||
|
||||
app.Add("JOHN", "/doe", testEmptyHandler)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, StatusBadRequest, resp.StatusCode, "Status code")
|
||||
|
||||
app.Add("JANE", "/doe", testEmptyHandler)
|
||||
}
|
||||
|
||||
// go test -run Test_App_GETOnly
|
||||
|
@ -487,7 +506,7 @@ func Test_App_Chaining(t *testing.T) {
|
|||
return c.SendStatus(202)
|
||||
})
|
||||
// check handler count for registered HEAD route
|
||||
utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)")
|
||||
utils.AssertEqual(t, 5, len(app.stack[app.methodInt(MethodHead)][0].Handlers), "app.Test(req)")
|
||||
|
||||
req := httptest.NewRequest(MethodPost, "/john", nil)
|
||||
|
||||
|
@ -1250,16 +1269,17 @@ func Test_App_Stack(t *testing.T) {
|
|||
app.Post("/path3", testEmptyHandler)
|
||||
|
||||
stack := app.Stack()
|
||||
utils.AssertEqual(t, 9, len(stack))
|
||||
utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)]))
|
||||
utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)]))
|
||||
utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)]))
|
||||
utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)]))
|
||||
methodList := app.config.RequestMethods
|
||||
utils.AssertEqual(t, len(methodList), len(stack))
|
||||
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodGet)]))
|
||||
utils.AssertEqual(t, 3, len(stack[app.methodInt(MethodHead)]))
|
||||
utils.AssertEqual(t, 2, len(stack[app.methodInt(MethodPost)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPut)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodPatch)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodDelete)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodConnect)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodOptions)]))
|
||||
utils.AssertEqual(t, 1, len(stack[app.methodInt(MethodTrace)]))
|
||||
}
|
||||
|
||||
// go test -run Test_App_HandlersCount
|
||||
|
@ -1513,6 +1533,19 @@ func Test_App_SetTLSHandler(t *testing.T) {
|
|||
utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName)
|
||||
}
|
||||
|
||||
func Test_App_AddCustomRequestMethod(t *testing.T) {
|
||||
methods := append(DefaultMethods, "TEST")
|
||||
app := New(Config{
|
||||
RequestMethods: methods,
|
||||
})
|
||||
appMethods := app.config.RequestMethods
|
||||
|
||||
// method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1
|
||||
utils.AssertEqual(t, len(app.stack), len(appMethods))
|
||||
utils.AssertEqual(t, len(app.stack), len(appMethods))
|
||||
utils.AssertEqual(t, "TEST", appMethods[len(appMethods)-1])
|
||||
}
|
||||
|
||||
func TestApp_GetRoutes(t *testing.T) {
|
||||
app := New()
|
||||
app.Use(func(c *Ctx) error {
|
||||
|
@ -1524,7 +1557,7 @@ func TestApp_GetRoutes(t *testing.T) {
|
|||
app.Delete("/delete", handler).Name("delete")
|
||||
app.Post("/post", handler).Name("post")
|
||||
routes := app.GetRoutes(false)
|
||||
utils.AssertEqual(t, 11, len(routes))
|
||||
utils.AssertEqual(t, 2+len(app.config.RequestMethods), len(routes))
|
||||
methodMap := map[string]string{"/delete": "delete", "/post": "post"}
|
||||
for _, route := range routes {
|
||||
name, ok := methodMap[route.Path]
|
||||
|
@ -1540,5 +1573,4 @@ func TestApp_GetRoutes(t *testing.T) {
|
|||
utils.AssertEqual(t, true, ok)
|
||||
utils.AssertEqual(t, name, route.Name)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
4
ctx.go
4
ctx.go
|
@ -163,7 +163,7 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
|
|||
c.pathOriginal = app.getString(fctx.URI().PathOriginal())
|
||||
// Set method
|
||||
c.method = app.getString(fctx.Request.Header.Method())
|
||||
c.methodINT = methodInt(c.method)
|
||||
c.methodINT = app.methodInt(c.method)
|
||||
// Attach *fasthttp.RequestCtx to ctx
|
||||
c.fasthttp = fctx
|
||||
// reset base uri
|
||||
|
@ -906,7 +906,7 @@ func (c *Ctx) Location(path string) {
|
|||
func (c *Ctx) Method(override ...string) string {
|
||||
if len(override) > 0 {
|
||||
method := utils.ToUpper(override[0])
|
||||
mINT := methodInt(method)
|
||||
mINT := c.app.methodInt(method)
|
||||
if mINT == -1 {
|
||||
return c.method
|
||||
}
|
||||
|
|
2
group.go
2
group.go
|
@ -133,7 +133,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router {
|
|||
|
||||
// All will register the handler on all HTTP methods
|
||||
func (grp *Group) All(path string, handlers ...Handler) Router {
|
||||
for _, method := range intMethod {
|
||||
for _, method := range grp.app.config.RequestMethods {
|
||||
_ = grp.Add(method, path, handlers...)
|
||||
}
|
||||
return grp
|
||||
|
|
74
helpers.go
74
helpers.go
|
@ -78,8 +78,9 @@ func (app *App) quoteString(raw string) string {
|
|||
}
|
||||
|
||||
// Scan stack if other methods match the request
|
||||
func methodExist(ctx *Ctx) (exist bool) {
|
||||
for i := 0; i < len(intMethod); i++ {
|
||||
func (app *App) methodExist(ctx *Ctx) (exist bool) {
|
||||
methods := app.config.RequestMethods
|
||||
for i := 0; i < len(methods); i++ {
|
||||
// Skip original method
|
||||
if ctx.methodINT == i {
|
||||
continue
|
||||
|
@ -109,7 +110,7 @@ func methodExist(ctx *Ctx) (exist bool) {
|
|||
// We matched
|
||||
exist = true
|
||||
// Add method to Allow header
|
||||
ctx.Append(HeaderAllow, intMethod[i])
|
||||
ctx.Append(HeaderAllow, methods[i])
|
||||
// Break stack loop
|
||||
break
|
||||
}
|
||||
|
@ -331,42 +332,41 @@ var getBytesImmutable = func(s string) (b []byte) {
|
|||
}
|
||||
|
||||
// HTTP methods and their unique INTs
|
||||
func methodInt(s string) int {
|
||||
switch s {
|
||||
case MethodGet:
|
||||
return 0
|
||||
case MethodHead:
|
||||
return 1
|
||||
case MethodPost:
|
||||
return 2
|
||||
case MethodPut:
|
||||
return 3
|
||||
case MethodDelete:
|
||||
return 4
|
||||
case MethodConnect:
|
||||
return 5
|
||||
case MethodOptions:
|
||||
return 6
|
||||
case MethodTrace:
|
||||
return 7
|
||||
case MethodPatch:
|
||||
return 8
|
||||
default:
|
||||
return -1
|
||||
func (app *App) methodInt(s string) int {
|
||||
// For better performance
|
||||
if !app.customMethod {
|
||||
switch s {
|
||||
case MethodGet:
|
||||
return 0
|
||||
case MethodHead:
|
||||
return 1
|
||||
case MethodPost:
|
||||
return 2
|
||||
case MethodPut:
|
||||
return 3
|
||||
case MethodDelete:
|
||||
return 4
|
||||
case MethodConnect:
|
||||
return 5
|
||||
case MethodOptions:
|
||||
return 6
|
||||
case MethodTrace:
|
||||
return 7
|
||||
case MethodPatch:
|
||||
return 8
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP methods slice
|
||||
var intMethod = []string{
|
||||
MethodGet,
|
||||
MethodHead,
|
||||
MethodPost,
|
||||
MethodPut,
|
||||
MethodDelete,
|
||||
MethodConnect,
|
||||
MethodOptions,
|
||||
MethodTrace,
|
||||
MethodPatch,
|
||||
// For method customization
|
||||
for i, v := range app.config.RequestMethods {
|
||||
if s == v {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// HTTP methods were copied from net/http.
|
||||
|
|
12
router.go
12
router.go
|
@ -139,7 +139,7 @@ func (app *App) next(c *Ctx) (match bool, err error) {
|
|||
|
||||
// If no match, scan stack again if other methods match the request
|
||||
// Moved from app.handler because middleware may break the route chain
|
||||
if !c.matched && methodExist(c) {
|
||||
if !c.matched && app.methodExist(c) {
|
||||
err = ErrMethodNotAllowed
|
||||
}
|
||||
return
|
||||
|
@ -216,7 +216,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
|
|||
// Uppercase HTTP methods
|
||||
method = utils.ToUpper(method)
|
||||
// Check if the HTTP method is valid unless it's USE
|
||||
if method != methodUse && methodInt(method) == -1 {
|
||||
if method != methodUse && app.methodInt(method) == -1 {
|
||||
panic(fmt.Sprintf("add: invalid http method %s\n", method))
|
||||
}
|
||||
// A route requires atleast one ctx handler
|
||||
|
@ -277,7 +277,7 @@ func (app *App) register(method, pathRaw string, group *Group, handlers ...Handl
|
|||
// Middleware route matches all HTTP methods
|
||||
if isUse {
|
||||
// Add route to all HTTP methods stack
|
||||
for _, m := range intMethod {
|
||||
for _, m := range app.config.RequestMethods {
|
||||
// Create a route copy to avoid duplicates during compression
|
||||
r := route
|
||||
app.addRoute(m, &r)
|
||||
|
@ -435,7 +435,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
|
|||
}
|
||||
|
||||
// Get unique HTTP method identifier
|
||||
m := methodInt(method)
|
||||
m := app.methodInt(method)
|
||||
|
||||
// prevent identically route registration
|
||||
l := len(app.stack[m])
|
||||
|
@ -469,7 +469,7 @@ func (app *App) buildTree() *App {
|
|||
}
|
||||
|
||||
// loop all the methods and stacks and create the prefix tree
|
||||
for m := range intMethod {
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := make(map[string][]*Route)
|
||||
for _, route := range app.stack[m] {
|
||||
treePath := ""
|
||||
|
@ -483,7 +483,7 @@ func (app *App) buildTree() *App {
|
|||
}
|
||||
|
||||
// loop the methods and tree stacks and add global stack and sort everything
|
||||
for m := range intMethod {
|
||||
for m := range app.config.RequestMethods {
|
||||
tsMap := app.treeStack[m]
|
||||
for treePart := range tsMap {
|
||||
if treePart != "" {
|
||||
|
|
Loading…
Reference in New Issue