diff --git a/context.go b/context.go index 7aecbbe6..b4868e45 100644 --- a/context.go +++ b/context.go @@ -141,6 +141,10 @@ func (ctx *Ctx) Set(key string, val string) { // Get : func (ctx *Ctx) Get(key string) string { + // https://en.wikipedia.org/wiki/HTTP_referer + if key == "referrer" { + key = "referer" + } return b2s(ctx.Fasthttp.Request.Header.Peek(key)) } @@ -224,7 +228,7 @@ func (ctx *Ctx) Is(ext string) bool { // Attachment : func (ctx *Ctx) Attachment(args ...interface{}) { if len(args) == 1 { - filename := args[0].(string) + filename := filepath.Base(args[0].(string)) ctx.Type(filepath.Ext(filename)) ctx.Set("Content-Disposition", `attachment; filename="`+filename+`"`) return @@ -232,9 +236,25 @@ func (ctx *Ctx) Attachment(args ...interface{}) { ctx.Set("Content-Disposition", "attachment") } +// Download : +func (ctx *Ctx) Download(args ...interface{}) { + var file string + var filename string + if len(args) == 1 { + file = args[0].(string) + filename = filepath.Base(file) + } + if len(args) == 2 { + file = args[0].(string) + filename = args[1].(string) + } + ctx.Set("Content-Disposition", "attachment; filename="+filename) + ctx.SendFile(file) +} + // SendFile : -func (ctx *Ctx) SendFile(path string) { - fasthttp.ServeFile(ctx.Fasthttp, path) +func (ctx *Ctx) SendFile(file string) { + fasthttp.ServeFile(ctx.Fasthttp, file) //ctx.Type(filepath.Ext(path)) //ctx.Fasthttp.SendFile(path) } diff --git a/helpers.go b/helpers.go index f85297b3..5187a989 100644 --- a/helpers.go +++ b/helpers.go @@ -1,6 +1,8 @@ package fiber import ( + "os" + "path/filepath" "reflect" "regexp" "strings" @@ -51,6 +53,20 @@ func getRegex(path string) (*regexp.Regexp, error) { return regex, err } +func walk(root string) (files []string, dir bool, err error) { + err = filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if !info.IsDir() { + if !strings.Contains(path, ".fasthttp.gz") { + files = append(files, path) + } + } else { + dir = true + } + return nil + }) + return files, dir, err +} + // Credits to @savsgio // https://github.com/savsgio/gotils/blob/master/conv.go diff --git a/router.go b/router.go index 994ea906..e603a67a 100644 --- a/router.go +++ b/router.go @@ -2,7 +2,9 @@ package fiber import ( "fmt" + "path/filepath" "regexp" + "strings" "time" "github.com/valyala/fasthttp" @@ -10,8 +12,8 @@ import ( type route struct { method string + any bool path string - anyPath bool regex *regexp.Regexp params []string handler func(*Ctx) @@ -53,7 +55,6 @@ type Fiber struct { // New : func New() *Fiber { return &Fiber{ - methods: []string{"GET", "PUT", "POST", "DELETE", "HEAD", "PATCH", "OPTIONS", "TRACE", "CONNECT"}, Settings: &Settings{ TLSEnable: false, CertKey: "", @@ -81,9 +82,9 @@ func New() *Fiber { } } -// Get : -func (r *Fiber) Get(args ...interface{}) { - r.register("GET", args...) +// Connect : +func (r *Fiber) Connect(args ...interface{}) { + r.register("CONNECT", args...) } // Put : @@ -121,68 +122,90 @@ func (r *Fiber) Trace(args ...interface{}) { r.register("TRACE", args...) } -// Connect : -func (r *Fiber) Connect(args ...interface{}) { - r.register("CONNECT", args...) +// Get : +func (r *Fiber) Get(args ...interface{}) { + r.register("GET", args...) } // All : func (r *Fiber) All(args ...interface{}) { r.register("*", args...) - // for _, method := range r.methods { - // r.register(method, args...) - // } } -// Use : -func (r *Fiber) Use(args ...interface{}) { - r.register("*", args...) - // for _, method := range r.methods { - // r.register(method, args...) - // } -} - -// register : func (r *Fiber) register(method string, args ...interface{}) { - // Pre-set variables for interface assertion - var ok bool + // Options var path string + var static string var handler func(*Ctx) - // Register only handler: app.Get(handler) + // app.Get(handler) if len(args) == 1 { - // Convert interface to func(*Context) - handler, ok = args[0].(func(*Ctx)) - if !ok { - panic("Invalid handler must be func(*express.Context)") + switch arg := args[0].(type) { + case string: + static = arg + case func(*Ctx): + handler = arg } } - // Register path and handler: app.Get(path, handler) + // app.Get(path, handler) if len(args) == 2 { - // Convert interface to path string - path, ok = args[0].(string) - if !ok { - panic("Invalid path") - } - // Panic if first char does not begins with / or * + path = args[0].(string) if path[0] != '/' && path[0] != '*' { panic("Invalid path, must begin with slash '/' or wildcard '*'") } - // Convert interface to func(*Context) - handler, ok = args[1].(func(*Ctx)) - if !ok { - panic("Invalid handler, must be func(*express.Context)") + switch arg := args[1].(type) { + case string: + static = arg + case func(*Ctx): + handler = arg } } - // If its a simple wildcard ( aka match anything ) + // Is this a static file handler? + if static != "" { + // static file route!! + r.registerStatic(method, path, static) + } else if handler != nil { + // function route!! + r.registerHandler(method, path, handler) + } else { + panic("Every route needs to contain either a dir/file path or callback function") + } +} +func (r *Fiber) registerStatic(method, prefix, root string) { + var any bool + if prefix == "*" || prefix == "/*" { + any = true + } + if prefix == "" { + prefix = "/" + } + files, _, err := walk(root) + if err != nil { + panic(err) + } + mount := filepath.Clean(root) + for _, file := range files { + path := filepath.Join(prefix, strings.Replace(file, mount, "", 1)) + filePath := file + if filepath.Base(filePath) == "index.html" { + r.routes = append(r.routes, &route{method, any, prefix, nil, nil, func(c *Ctx) { + c.SendFile(filePath) + }}) + } + r.routes = append(r.routes, &route{method, any, path, nil, nil, func(c *Ctx) { + c.SendFile(filePath) + }}) + } +} +func (r *Fiber) registerHandler(method, path string, handler func(*Ctx)) { if path == "" || path == "*" || path == "/*" { - r.routes = append(r.routes, &route{method, path, true, nil, nil, handler}) + r.routes = append(r.routes, &route{method, true, path, nil, nil, handler}) return } // Get params from path params := getParams(path) // If path has no params, we dont need regex if len(params) == 0 { - r.routes = append(r.routes, &route{method, path, false, nil, nil, handler}) + r.routes = append(r.routes, &route{method, false, path, nil, nil, handler}) return } @@ -191,7 +214,7 @@ func (r *Fiber) register(method string, args ...interface{}) { if err != nil { panic("Invalid url pattern: " + path) } - r.routes = append(r.routes, &route{method, path, false, regex, params, handler}) + r.routes = append(r.routes, &route{method, false, path, regex, params, handler}) } // handler : @@ -208,9 +231,9 @@ func (r *Fiber) handler(fctx *fasthttp.RequestCtx) { continue } // First check if we match a static path or wildcard - if route.anyPath || (route.path == path && route.params == nil) { + if route.any || (route.path == path && route.params == nil) { // If * always set the path to the wildcard parameter - if route.anyPath { + if route.any { ctx.params = &[]string{"*"} ctx.values = []string{path} } @@ -252,7 +275,16 @@ func (r *Fiber) handler(fctx *fasthttp.RequestCtx) { } // Listen : -func (r *Fiber) Listen(port int) { +func (r *Fiber) Listen(args ...interface{}) { + var port int + var addr string + if len(args) == 1 { + port = args[0].(int) + } + if len(args) == 2 { + addr = args[0].(string) + port = args[1].(int) + } // Disable server header if server name is not given if r.Settings.Name != "" { r.Settings.NoDefaultServerHeader = false @@ -282,12 +314,12 @@ func (r *Fiber) Listen(port int) { KeepHijackedConns: r.Settings.KeepHijackedConns, } if r.Settings.TLSEnable { - if err := server.ListenAndServeTLS(fmt.Sprintf(":%v", port), r.Settings.CertFile, r.Settings.CertKey); err != nil { + if err := server.ListenAndServeTLS(fmt.Sprintf("%s:%v", addr, port), r.Settings.CertFile, r.Settings.CertKey); err != nil { panic(err) } return } - if err := server.ListenAndServe(fmt.Sprintf(":%v", port)); err != nil { + if err := server.ListenAndServe(fmt.Sprintf("%s:%v", addr, port)); err != nil { panic(err) } }