diff --git a/middleware/adaptor/README.md b/middleware/adaptor/README.md new file mode 100644 index 00000000..363ad6b1 --- /dev/null +++ b/middleware/adaptor/README.md @@ -0,0 +1,141 @@ +# Adaptor + +![Release](https://img.shields.io/github/release/gofiber/adaptor.svg) +[![Discord](https://img.shields.io/badge/discord-join%20channel-7289DA)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/adaptor/workflows/Test/badge.svg) +![Security](https://github.com/gofiber/adaptor/workflows/Security/badge.svg) +![Linter](https://github.com/gofiber/adaptor/workflows/Linter/badge.svg) + +Converter for net/http handlers to/from Fiber request handlers, special thanks to [@arsmn](https://github.com/arsmn)! + +### Install +``` +go get -u github.com/gofiber/fiber/v2 +go get -u github.com/gofiber/adaptor/v2 +``` + +### Functions +| Name | Signature | Description +| :--- | :--- | :--- +| HTTPHandler | `HTTPHandler(h http.Handler) fiber.Handler` | http.Handler -> fiber.Handler +| HTTPHandlerFunc | `HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler` | http.HandlerFunc -> fiber.Handler +| HTTPMiddleware | `HTTPHandlerFunc(mw func(http.Handler) http.Handler) fiber.Handler` | func(http.Handler) http.Handler -> fiber.Handler +| FiberHandler | `FiberHandler(h fiber.Handler) http.Handler` | fiber.Handler -> http.Handler +| FiberHandlerFunc | `FiberHandlerFunc(h fiber.Handler) http.HandlerFunc` | fiber.Handler -> http.HandlerFunc +| FiberApp | `FiberApp(app *fiber.App) http.HandlerFunc` | Fiber app -> http.HandlerFunc + +### net/http to Fiber +```go +package main + +import ( + "fmt" + "net/http" + + "github.com/gofiber/adaptor/v2" + "github.com/gofiber/fiber/v2" +) + +func main() { + // New fiber app + app := fiber.New() + + // http.Handler -> fiber.Handler + app.Get("/", adaptor.HTTPHandler(handler(greet))) + + // http.HandlerFunc -> fiber.Handler + app.Get("/func", adaptor.HTTPHandlerFunc(greet)) + + // Listen on port 3000 + app.Listen(":3000") +} + +func handler(f http.HandlerFunc) http.Handler { + return http.HandlerFunc(f) +} + +func greet(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello World!") +} +``` + +### net/http middleware to Fiber +```go +package main + +import ( + "log" + "net/http" + + "github.com/gofiber/adaptor/v2" + "github.com/gofiber/fiber/v2" +) + +func main() { + // New fiber app + app := fiber.New() + + // http middleware -> fiber.Handler + app.Use(adaptor.HTTPMiddleware(logMiddleware)) + + // Listen on port 3000 + app.Listen(":3000") +} + +func logMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("log middleware") + next.ServeHTTP(w, r) + }) +} +``` + +### Fiber Handler to net/http +```go +package main + +import ( + "net/http" + + "github.com/gofiber/adaptor/v2" + "github.com/gofiber/fiber/v2" +) + +func main() { + // fiber.Handler -> http.Handler + http.Handle("/", adaptor.FiberHandler(greet)) + + // fiber.Handler -> http.HandlerFunc + http.HandleFunc("/func", adaptor.FiberHandlerFunc(greet)) + + // Listen on port 3000 + http.ListenAndServe(":3000", nil) +} + +func greet(c *fiber.Ctx) error { + return c.SendString("Hello World!") +} +``` + +### Fiber App to net/http +```go +package main + +import ( + "github.com/gofiber/adaptor/v2" + "github.com/gofiber/fiber/v2" + "net/http" +) +func main() { + app := fiber.New() + + app.Get("/greet", greet) + + // Listen on port 3000 + http.ListenAndServe(":3000", adaptor.FiberApp(app)) +} + +func greet(c *fiber.Ctx) error { + return c.SendString("Hello World!") +} +``` diff --git a/middleware/adaptor/adopter.go b/middleware/adaptor/adopter.go new file mode 100644 index 00000000..eaacab03 --- /dev/null +++ b/middleware/adaptor/adopter.go @@ -0,0 +1,127 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package adaptor + +import ( + "io/ioutil" + "net" + "net/http" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" +) + +// HTTPHandlerFunc wraps net/http handler func to fiber handler +func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler { + return HTTPHandler(h) +} + +// HTTPHandler wraps net/http handler to fiber handler +func HTTPHandler(h http.Handler) fiber.Handler { + return func(c *fiber.Ctx) error { + handler := fasthttpadaptor.NewFastHTTPHandler(h) + handler(c.Context()) + return nil + } +} + +// HTTPMiddleware wraps net/http middleware to fiber middleware +func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler { + return func(c *fiber.Ctx) error { + var next bool + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next = true + // Convert again in case request may modify by middleware + c.Request().Header.SetMethod(r.Method) + c.Request().SetRequestURI(r.RequestURI) + c.Request().SetHost(r.Host) + for key, val := range r.Header { + for _, v := range val { + c.Request().Header.Set(key, v) + } + } + }) + _ = HTTPHandler(mw(nextHandler))(c) + if next { + return c.Next() + } + return nil + } +} + +// FiberHandler wraps fiber handler to net/http handler +func FiberHandler(h fiber.Handler) http.Handler { + return FiberHandlerFunc(h) +} + +// FiberHandlerFunc wraps fiber handler to net/http handler func +func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc { + return handlerFunc(fiber.New(), h) +} + +// FiberApp wraps fiber app to net/http handler func +func FiberApp(app *fiber.App) http.HandlerFunc { + return handlerFunc(app) +} + +func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // New fasthttp request + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + // Convert net/http -> fasthttp request + if r.Body != nil { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) + return + } + req.Header.SetContentLength(len(body)) + _, _ = req.BodyWriter().Write(body) + } + req.Header.SetMethod(r.Method) + req.SetRequestURI(r.RequestURI) + req.SetHost(r.Host) + for key, val := range r.Header { + for _, v := range val { + req.Header.Set(key, v) + } + } + if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { + r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") + } + remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if err != nil { + http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) + return + } + + // New fasthttp Ctx + var fctx fasthttp.RequestCtx + fctx.Init(req, remoteAddr, nil) + if len(h) > 0 { + // New fiber Ctx + ctx := app.AcquireCtx(&fctx) + defer app.ReleaseCtx(ctx) + // Execute fiber Ctx + err := h[0](ctx) + if err != nil { + _ = app.Config().ErrorHandler(ctx, err) + } + } else { + // Execute fasthttp Ctx though app.Handler + app.Handler()(&fctx) + } + + // Convert fasthttp Ctx > net/http + fctx.Response.Header.VisitAll(func(k, v []byte) { + w.Header().Add(string(k), string(v)) + }) + w.WriteHeader(fctx.Response.StatusCode()) + _, _ = w.Write(fctx.Response.Body()) + } +} diff --git a/middleware/adaptor/adopter_test.go b/middleware/adaptor/adopter_test.go new file mode 100644 index 00000000..6b08a117 --- /dev/null +++ b/middleware/adaptor/adopter_test.go @@ -0,0 +1,413 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package adaptor + +import ( + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/valyala/fasthttp" +) + +func Test_HTTPHandler(t *testing.T) { + expectedMethod := fiber.MethodPost + expectedProto := "HTTP/1.1" + expectedProtoMajor := 1 + expectedProtoMinor := 1 + expectedRequestURI := "/foo/bar?baz=123" + expectedBody := "body 123 foo bar baz" + expectedContentLength := len(expectedBody) + expectedHost := "foobar.com" + expectedRemoteAddr := "1.2.3.4:6789" + expectedHeader := map[string]string{ + "Foo-Bar": "baz", + "Abc": "defg", + "XXX-Remote-Addr": "123.43.4543.345", + } + expectedURL, err := url.ParseRequestURI(expectedRequestURI) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + expectedContextKey := "contextKey" + expectedContextValue := "contextValue" + + callsCount := 0 + nethttpH := func(w http.ResponseWriter, r *http.Request) { + callsCount++ + if r.Method != expectedMethod { + t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod) + } + if r.Proto != expectedProto { + t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto) + } + if r.ProtoMajor != expectedProtoMajor { + t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor) + } + if r.ProtoMinor != expectedProtoMinor { + t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor) + } + if r.RequestURI != expectedRequestURI { + t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI) + } + if r.ContentLength != int64(expectedContentLength) { + t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength) + } + if len(r.TransferEncoding) != 0 { + t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding) + } + if r.Host != expectedHost { + t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost) + } + if r.RemoteAddr != expectedRemoteAddr { + t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr) + } + body, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + t.Fatalf("unexpected error when reading request body: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + if !reflect.DeepEqual(r.URL, expectedURL) { + t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL) + } + if r.Context().Value(expectedContextKey) != expectedContextValue { + t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue) + } + + for k, expectedV := range expectedHeader { + v := r.Header.Get(k) + if v != expectedV { + t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV) + } + } + + w.Header().Set("Header1", "value1") + w.Header().Set("Header2", "value2") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "request body is %q", body) + } + fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH)) + fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue) + + var fctx fasthttp.RequestCtx + var req fasthttp.Request + + req.Header.SetMethod(expectedMethod) + req.SetRequestURI(expectedRequestURI) + req.Header.SetHost(expectedHost) + req.BodyWriter().Write([]byte(expectedBody)) // nolint:errcheck + for k, v := range expectedHeader { + req.Header.Set(k, v) + } + + remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + fctx.Init(&req, remoteAddr, nil) + app := fiber.New() + ctx := app.AcquireCtx(&fctx) + defer app.ReleaseCtx(ctx) + + fiberH(ctx) + + if callsCount != 1 { + t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount) + } + + resp := &fctx.Response + if resp.StatusCode() != fiber.StatusBadRequest { + t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fiber.StatusBadRequest) + } + if string(resp.Header.Peek("Header1")) != "value1" { + t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header1"), "value1") + } + if string(resp.Header.Peek("Header2")) != "value2" { + t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2") + } + expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) + if string(resp.Body()) != expectedResponseBody { + t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody) + } +} + +func Test_HTTPMiddleware(t *testing.T) { + + tests := []struct { + name string + url string + method string + statusCode int + }{ + { + name: "Should return 200", + url: "/", + method: "POST", + statusCode: 200, + }, + { + name: "Should return 405", + url: "/", + method: "GET", + statusCode: 405, + }, + { + name: "Should return 400", + url: "/unknown", + method: "POST", + statusCode: 404, + }, + } + + nethttpMW := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + next.ServeHTTP(w, r) + }) + } + + app := fiber.New() + app.Use(HTTPMiddleware(nethttpMW)) + app.Post("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + for _, tt := range tests { + req, _ := http.NewRequest(tt.method, tt.url, nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf(`%s: %s`, t.Name(), err) + } + if resp.StatusCode != tt.statusCode { + t.Fatalf(`%s: StatusCode: got %v - expected %v`, t.Name(), resp.StatusCode, tt.statusCode) + } + } +} + +func Test_FiberHandler(t *testing.T) { + testFiberToHandlerFunc(t, false) +} + +func Test_FiberApp(t *testing.T) { + testFiberToHandlerFunc(t, false, fiber.New()) +} + +func Test_FiberHandlerDefaultPort(t *testing.T) { + testFiberToHandlerFunc(t, true) +} + +func Test_FiberAppDefaultPort(t *testing.T) { + testFiberToHandlerFunc(t, true, fiber.New()) +} + +func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) { + expectedMethod := fiber.MethodPost + expectedRequestURI := "/foo/bar?baz=123" + expectedBody := "body 123 foo bar baz" + expectedContentLength := len(expectedBody) + expectedHost := "foobar.com" + expectedRemoteAddr := "1.2.3.4:6789" + if checkDefaultPort { + expectedRemoteAddr = "1.2.3.4:80" + } + expectedHeader := map[string]string{ + "Foo-Bar": "baz", + "Abc": "defg", + "XXX-Remote-Addr": "123.43.4543.345", + } + expectedURL, err := url.ParseRequestURI(expectedRequestURI) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + callsCount := 0 + fiberH := func(c *fiber.Ctx) error { + callsCount++ + if c.Method() != expectedMethod { + t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod) + } + if string(c.Context().RequestURI()) != expectedRequestURI { + t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Context().RequestURI()), expectedRequestURI) + } + contentLength := c.Context().Request.Header.ContentLength() + if contentLength != expectedContentLength { + t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength) + } + if c.Hostname() != expectedHost { + t.Fatalf("unexpected host %q. Expecting %q", c.Hostname(), expectedHost) + } + remoteAddr := c.Context().RemoteAddr().String() + if remoteAddr != expectedRemoteAddr { + t.Fatalf("unexpected remoteAddr %q. Expecting %q", remoteAddr, expectedRemoteAddr) + } + body := string(c.Body()) + if body != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + if c.OriginalURL() != expectedURL.String() { + t.Fatalf("unexpected URL: %#v. Expecting %#v", c.OriginalURL(), expectedURL) + } + + for k, expectedV := range expectedHeader { + v := c.Get(k) + if v != expectedV { + t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV) + } + } + + c.Set("Header1", "value1") + c.Set("Header2", "value2") + c.Status(fiber.StatusBadRequest) + _, err := c.Write([]byte(fmt.Sprintf("request body is %q", body))) + return err + } + + var handlerFunc http.HandlerFunc + if len(app) > 0 { + app[0].Post("/foo/bar", fiberH) + handlerFunc = FiberApp(app[0]) + } else { + handlerFunc = FiberHandlerFunc(fiberH) + } + + var r http.Request + + r.Method = expectedMethod + r.Body = &netHTTPBody{[]byte(expectedBody)} + r.RequestURI = expectedRequestURI + r.ContentLength = int64(expectedContentLength) + r.Host = expectedHost + r.RemoteAddr = expectedRemoteAddr + if checkDefaultPort { + r.RemoteAddr = "1.2.3.4" + } + + hdr := make(http.Header) + for k, v := range expectedHeader { + hdr.Set(k, v) + } + r.Header = hdr + + var w netHTTPResponseWriter + handlerFunc.ServeHTTP(&w, &r) + + if w.StatusCode() != http.StatusBadRequest { + t.Fatalf("unexpected statusCode: %d. Expecting %d", w.StatusCode(), http.StatusBadRequest) + } + if w.Header().Get("Header1") != "value1" { + t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header1"), "value1") + } + if w.Header().Get("Header2") != "value2" { + t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header2"), "value2") + } + expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) + if string(w.body) != expectedResponseBody { + t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody) + } +} + +func setFiberContextValueMiddleware(next fiber.Handler, key string, value interface{}) fiber.Handler { + return func(c *fiber.Ctx) error { + c.Locals(key, value) + return next(c) + } +} + +func Test_FiberHandler_RequestNilBody(t *testing.T) { + expectedMethod := fiber.MethodGet + expectedRequestURI := "/foo/bar" + expectedContentLength := 0 + + callsCount := 0 + fiberH := func(c *fiber.Ctx) error { + callsCount++ + if c.Method() != expectedMethod { + t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod) + } + if string(c.Request().RequestURI()) != expectedRequestURI { + t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Request().RequestURI()), expectedRequestURI) + } + contentLength := c.Request().Header.ContentLength() + if contentLength != expectedContentLength { + t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength) + } + + _, err := c.Write([]byte("request body is nil")) + return err + } + nethttpH := FiberHandler(fiberH) + + var r http.Request + + r.Method = expectedMethod + r.RequestURI = expectedRequestURI + + var w netHTTPResponseWriter + nethttpH.ServeHTTP(&w, &r) + + expectedResponseBody := "request body is nil" + if string(w.body) != expectedResponseBody { + t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody) + } +} + +type netHTTPBody struct { + b []byte +} + +func (r *netHTTPBody) Read(p []byte) (int, error) { + if len(r.b) == 0 { + return 0, io.EOF + } + n := copy(p, r.b) + r.b = r.b[n:] + return n, nil +} + +func (r *netHTTPBody) Close() error { + r.b = r.b[:0] + return nil +} + +type netHTTPResponseWriter struct { + statusCode int + h http.Header + body []byte +} + +func (w *netHTTPResponseWriter) StatusCode() int { + if w.statusCode == 0 { + return http.StatusOK + } + return w.statusCode +} + +func (w *netHTTPResponseWriter) Header() http.Header { + if w.h == nil { + w.h = make(http.Header) + } + return w.h +} + +func (w *netHTTPResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { + w.body = append(w.body, p...) + return len(p), nil +} diff --git a/middleware/helmet/README.md b/middleware/helmet/README.md new file mode 100644 index 00000000..c3152735 --- /dev/null +++ b/middleware/helmet/README.md @@ -0,0 +1,38 @@ +# Helmet + +![Release](https://img.shields.io/github/release/gofiber/helmet.svg) +[![Discord](https://img.shields.io/badge/discord-join%20channel-7289DA)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/helmet/workflows/Test/badge.svg) +![Security](https://github.com/gofiber/helmet/workflows/Security/badge.svg) +![Linter](https://github.com/gofiber/helmet/workflows/Linter/badge.svg) + +### Install +``` +go get -u github.com/gofiber/fiber/v2 +go get -u github.com/gofiber/helmet/v2 +``` +### Example +```go +package main + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/helmet/v2" +) + +func main() { + app := fiber.New() + + app.Use(helmet.New()) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Welcome!") + }) + + app.Listen(":3000") +} +``` +### Test +```curl +curl -I http://localhost:3000 +``` diff --git a/middleware/helmet/helmet.go b/middleware/helmet/helmet.go new file mode 100644 index 00000000..3d70d0a5 --- /dev/null +++ b/middleware/helmet/helmet.go @@ -0,0 +1,110 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://docs.gofiber.io/ +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package helmet + +import ( + "fmt" + + "github.com/gofiber/fiber/v3" +) + +// Config ... +type Config struct { + // Filter defines a function to skip middleware. + // Optional. Default: nil + Filter func(*fiber.Ctx) bool + // XSSProtection + // Optional. Default value "1; mode=block". + XSSProtection string + // ContentTypeNosniff + // Optional. Default value "nosniff". + ContentTypeNosniff string + // XFrameOptions + // Optional. Default value "SAMEORIGIN". + // Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri" + XFrameOptions string + // HSTSMaxAge + // Optional. Default value 0. + HSTSMaxAge int + // HSTSExcludeSubdomains + // Optional. Default value false. + HSTSExcludeSubdomains bool + // ContentSecurityPolicy + // Optional. Default value "". + ContentSecurityPolicy string + // CSPReportOnly + // Optional. Default value false. + CSPReportOnly bool + // HSTSPreloadEnabled + // Optional. Default value false. + HSTSPreloadEnabled bool + // ReferrerPolicy + // Optional. Default value "". + ReferrerPolicy string + + // Permissions-Policy + // Optional. Default value "". + PermissionPolicy string +} + +// New ... +func New(config ...Config) fiber.Handler { + // Init config + var cfg Config + if len(config) > 0 { + cfg = config[0] + } + // Set config default values + if cfg.XSSProtection == "" { + cfg.XSSProtection = "1; mode=block" + } + if cfg.ContentTypeNosniff == "" { + cfg.ContentTypeNosniff = "nosniff" + } + if cfg.XFrameOptions == "" { + cfg.XFrameOptions = "SAMEORIGIN" + } + // Return middleware handler + return func(c *fiber.Ctx) error { + // Filter request to skip middleware + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + if cfg.XSSProtection != "" { + c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection) + } + if cfg.ContentTypeNosniff != "" { + c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff) + } + if cfg.XFrameOptions != "" { + c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions) + } + if (c.Secure() || (c.Get(fiber.HeaderXForwardedProto) == "https")) && cfg.HSTSMaxAge != 0 { + subdomains := "" + if !cfg.HSTSExcludeSubdomains { + subdomains = "; includeSubdomains" + } + if cfg.HSTSPreloadEnabled { + subdomains = fmt.Sprintf("%s; preload", subdomains) + } + c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains)) + } + if cfg.ContentSecurityPolicy != "" { + if cfg.CSPReportOnly { + c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy) + } else { + c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy) + } + } + if cfg.ReferrerPolicy != "" { + c.Set(fiber.HeaderReferrerPolicy, cfg.ReferrerPolicy) + } + if cfg.PermissionPolicy != "" { + c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy) + + } + return c.Next() + } +} diff --git a/middleware/helmet/helmet_test.go b/middleware/helmet/helmet_test.go new file mode 100644 index 00000000..2c4ca1ea --- /dev/null +++ b/middleware/helmet/helmet_test.go @@ -0,0 +1,108 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://docs.gofiber.io/ +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package helmet + +import ( + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" +) + +func Test_Default(t *testing.T) { + app := fiber.New() + + app.Use(New()) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "1; mode=block", resp.Header.Get(fiber.HeaderXXSSProtection)) + utils.AssertEqual(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions)) + utils.AssertEqual(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions)) + utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy)) + utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderReferrerPolicy)) + utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderPermissionsPolicy)) +} + +func Test_Filter(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{ + Filter: func(ctx *fiber.Ctx) bool { + return ctx.Path() == "/filter" + }, + ReferrerPolicy: "no-referrer", + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + app.Get("/filter", func(c *fiber.Ctx) error { + return c.SendString("Skipped!") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy)) + + resp, err = app.Test(httptest.NewRequest("GET", "/filter", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderReferrerPolicy)) +} + +func Test_ContentSecurityPolicy(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{ + ContentSecurityPolicy: "default-src 'none'", + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicy)) +} + +func Test_ContentSecurityPolicyReportOnly(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{ + ContentSecurityPolicy: "default-src 'none'", + CSPReportOnly: true, + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly)) + utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy)) +} + +func Test_PermissionsPolicy(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{ + PermissionPolicy: "microphone=()", + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy)) +} diff --git a/middleware/keyauth/README.md b/middleware/keyauth/README.md new file mode 100644 index 00000000..6b61e78c --- /dev/null +++ b/middleware/keyauth/README.md @@ -0,0 +1,44 @@ +# Key Authentication + +![Release](https://img.shields.io/github/release/gofiber/keyauth.svg) +[![Discord](https://img.shields.io/badge/discord-join%20channel-7289DA)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/keyauth/workflows/Test/badge.svg) +![Security](https://github.com/gofiber/keyauth/workflows/Security/badge.svg) +![Linter](https://github.com/gofiber/keyauth/workflows/Linter/badge.svg) + +Special thanks to [Jรณzsef Sallai](https://github.com/jozsefsallai) & [Ray Mayemir](https://github.com/raymayemir) + +### Install +``` +go get -u github.com/gofiber/fiber/v2 +go get -u github.com/gofiber/keyauth/v2 +``` +### Example +```go +package main + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/keyauth/v2" +) + +func main() { + app := fiber.New() + + app.Use(keyauth.New(keyauth.Config{ + KeyLookup: "cookie:access_token", + ContextKey: "my_token", + })) + + app.Get("/", func(c *fiber.Ctx) error { + token, _ := c.Locals("my_token").(string) + return c.SendString(token) + }) + + app.Listen(":3000") +} +``` +### Test +```curl +curl -v --cookie "access_token=hello_world" http://localhost:3000 +``` diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go new file mode 100644 index 00000000..c85e8a3a --- /dev/null +++ b/middleware/keyauth/keyauth.go @@ -0,0 +1,188 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber +// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/key_auth.go +package keyauth + +import ( + "errors" + "strings" + + "github.com/gofiber/fiber/v3" +) + +var ( + // When there is no request of the key thrown ErrMissingOrMalformedAPIKey + ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key") +) + +type Config struct { + // Filter defines a function to skip middleware. + // Optional. Default: nil + Filter func(*fiber.Ctx) bool + + // SuccessHandler defines a function which is executed for a valid key. + // Optional. Default: nil + SuccessHandler fiber.Handler + + // ErrorHandler defines a function which is executed for an invalid key. + // It may be used to define a custom error. + // Optional. Default: 401 Invalid or expired key + ErrorHandler fiber.ErrorHandler + + // KeyLookup is a string in the form of ":" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + // - "form:" + // - "param:" + // - "cookie:" + KeyLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Optional. Default: nil + Validator func(*fiber.Ctx, string) (bool, error) + + // Context key to store the bearertoken from the token into context. + // Optional. Default: "token". + ContextKey string +} + +// New ... +func New(config ...Config) fiber.Handler { + // Init config + var cfg Config + if len(config) > 0 { + cfg = config[0] + } + + if cfg.SuccessHandler == nil { + cfg.SuccessHandler = func(c *fiber.Ctx) error { + return c.Next() + } + } + if cfg.ErrorHandler == nil { + cfg.ErrorHandler = func(c *fiber.Ctx, err error) error { + if err == ErrMissingOrMalformedAPIKey { + return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + } + return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key") + } + } + if cfg.KeyLookup == "" { + cfg.KeyLookup = "header:" + fiber.HeaderAuthorization + // set AuthScheme as "Bearer" only if KeyLookup is set to default. + if cfg.AuthScheme == "" { + cfg.AuthScheme = "Bearer" + } + } + if cfg.Validator == nil { + cfg.Validator = func(c *fiber.Ctx, t string) (bool, error) { + return true, nil + } + } + if cfg.ContextKey == "" { + cfg.ContextKey = "token" + } + + // Initialize + parts := strings.Split(cfg.KeyLookup, ":") + extractor := keyFromHeader(parts[1], cfg.AuthScheme) + switch parts[0] { + case "query": + extractor = keyFromQuery(parts[1]) + case "form": + extractor = keyFromForm(parts[1]) + case "param": + extractor = keyFromParam(parts[1]) + case "cookie": + extractor = keyFromCookie(parts[1]) + } + + // Return middleware handler + return func(c *fiber.Ctx) error { + // Filter request to skip middleware + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + + // Extract and verify key + key, err := extractor(c) + if err != nil { + return cfg.ErrorHandler(c, err) + } + + valid, err := cfg.Validator(c, key) + + if err == nil && valid { + c.Locals(cfg.ContextKey, key) + return cfg.SuccessHandler(c) + } + return cfg.ErrorHandler(c, err) + } +} + +// keyFromHeader returns a function that extracts api key from the request header. +func keyFromHeader(header string, authScheme string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + auth := c.Get(header) + l := len(authScheme) + if len(auth) > 0 && l == 0 { + return auth, nil + } + if len(auth) > l+1 && auth[:l] == authScheme { + return auth[l+1:], nil + } + return "", ErrMissingOrMalformedAPIKey + } +} + +// keyFromQuery returns a function that extracts api key from the query string. +func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + key := c.Query(param) + if key == "" { + return "", ErrMissingOrMalformedAPIKey + } + return key, nil + } +} + +// keyFromForm returns a function that extracts api key from the form. +func keyFromForm(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + key := c.FormValue(param) + if key == "" { + return "", ErrMissingOrMalformedAPIKey + } + return key, nil + } +} + +// keyFromParam returns a function that extracts api key from the url param string. +func keyFromParam(param string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + key := c.Params(param) + if key == "" { + return "", ErrMissingOrMalformedAPIKey + } + return key, nil + } +} + +// keyFromCookie returns a function that extracts api key from the named cookie. +func keyFromCookie(name string) func(c *fiber.Ctx) (string, error) { + return func(c *fiber.Ctx) (string, error) { + key := c.Cookies(name) + if key == "" { + return "", ErrMissingOrMalformedAPIKey + } + return key, nil + } +} diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go new file mode 100644 index 00000000..47dd09e0 --- /dev/null +++ b/middleware/keyauth/keyauth_test.go @@ -0,0 +1,5 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package keyauth diff --git a/middleware/redirect/README.md b/middleware/redirect/README.md new file mode 100644 index 00000000..190a2ab5 --- /dev/null +++ b/middleware/redirect/README.md @@ -0,0 +1,48 @@ +# Redirect + +![Release](https://img.shields.io/github/release/gofiber/redirect.svg) +[![Discord](https://img.shields.io/badge/discord-join%20channel-7289DA)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/redirect/workflows/Test/badge.svg) +![Security](https://github.com/gofiber/redirect/workflows/Security/badge.svg) +![Linter](https://github.com/gofiber/redirect/workflows/Linter/badge.svg) + +### Install +``` +go get -u github.com/gofiber/fiber/v2 +go get -u github.com/gofiber/redirect/v2 +``` +### Example +```go +package main + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/redirect/v2" +) + +func main() { + app := fiber.New() + + app.Use(redirect.New(redirect.Config{ + Rules: map[string]string{ + "/old": "/new", + "/old/*": "/new/$1", + }, + StatusCode: 301, + })) + + app.Get("/new", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + app.Get("/new/*", func(c *fiber.Ctx) error { + return c.SendString("Wildcard: " + c.Params("*")) + }) + + app.Listen(":3000") +} +``` +### Test +```curl +curl http://localhost:3000/old +curl http://localhost:3000/old/hello +``` diff --git a/middleware/redirect/redirect.go b/middleware/redirect/redirect.go new file mode 100644 index 00000000..cb5e198f --- /dev/null +++ b/middleware/redirect/redirect.go @@ -0,0 +1,88 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package redirect + +import ( + "regexp" + "strconv" + "strings" + + "github.com/gofiber/fiber/v3" +) + +// Config ... +type Config struct { + // Filter defines a function to skip middleware. + // Optional. Default: nil + Filter func(*fiber.Ctx) bool + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Required. Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rules map[string]string + // The status code when redirecting + // This is ignored if Redirect is disabled + // Optional. Default: 302 Temporary Redirect + StatusCode int + + rulesRegex map[*regexp.Regexp]string +} + +// New ... +func New(config ...Config) fiber.Handler { + // Init config + var cfg Config + if len(config) > 0 { + cfg = config[0] + } + if cfg.StatusCode == 0 { + cfg.StatusCode = 302 // Temporary Redirect + } + cfg = config[0] + cfg.rulesRegex = map[*regexp.Regexp]string{} + // Initialize + for k, v := range cfg.Rules { + k = strings.Replace(k, "*", "(.*)", -1) + k = k + "$" + cfg.rulesRegex[regexp.MustCompile(k)] = v + } + // Middleware function + return func(c *fiber.Ctx) error { + // Filter request to skip middleware + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + // Rewrite + for k, v := range cfg.rulesRegex { + replacer := captureTokens(k, c.Path()) + if replacer != nil { + return c.Redirect(replacer.Replace(v), cfg.StatusCode) + } + } + return c.Next() + } +} + +// https://github.com/labstack/echo/blob/master/middleware/rewrite.go +func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { + if len(input) > 1 { + input = strings.TrimSuffix(input, "/") + } + groups := pattern.FindAllStringSubmatch(input, -1) + if groups == nil { + return nil + } + values := groups[0][1:] + replace := make([]string, 2*len(values)) + for i, v := range values { + j := 2 * i + replace[j] = "$" + strconv.Itoa(i+1) + replace[j+1] = v + } + return strings.NewReplacer(replace...) +} diff --git a/middleware/redirect/redirect_test.go b/middleware/redirect/redirect_test.go new file mode 100644 index 00000000..76ef5534 --- /dev/null +++ b/middleware/redirect/redirect_test.go @@ -0,0 +1,126 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package redirect + +import ( + "net/http" + "testing" + + "github.com/gofiber/fiber/v3" +) + +func Test_Redirect(t *testing.T) { + app := *fiber.New() + + app.Use(New(Config{ + Rules: map[string]string{ + "/default": "google.com", + }, + StatusCode: 301, + })) + app.Use(New(Config{ + Rules: map[string]string{ + "/default/*": "fiber.wiki", + }, + StatusCode: 307, + })) + app.Use(New(Config{ + Rules: map[string]string{ + "/redirect/*": "$1", + }, + StatusCode: 303, + })) + app.Use(New(Config{ + Rules: map[string]string{ + "/pattern/*": "golang.org", + }, + StatusCode: 302, + })) + + app.Use(New(Config{ + Rules: map[string]string{ + "/": "/swagger", + }, + StatusCode: 301, + })) + + app.Get("/api/*", func(c *fiber.Ctx) error { + return c.SendString("API") + }) + + app.Get("/new", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + tests := []struct { + name string + url string + redirectTo string + statusCode int + }{ + { + name: "should be returns status 302 without a wildcard", + url: "/default", + redirectTo: "google.com", + statusCode: 301, + }, + { + name: "should be returns status 307 using wildcard", + url: "/default/xyz", + redirectTo: "fiber.wiki", + statusCode: 307, + }, + { + name: "should be returns status 303 without set redirectTo to use the default", + url: "/redirect/github.com/gofiber/redirect", + redirectTo: "github.com/gofiber/redirect", + statusCode: 303, + }, + { + name: "should return the status code default", + url: "/pattern/xyz", + redirectTo: "golang.org", + statusCode: 302, + }, + { + name: "access URL without rule", + url: "/new", + statusCode: 200, + }, + { + name: "redirect to swagger route", + url: "/", + redirectTo: "/swagger", + statusCode: 301, + }, + { + name: "no redirect to swagger route", + url: "/api/", + statusCode: 200, + }, + { + name: "no redirect to swagger route #2", + url: "/api/test", + statusCode: 200, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", tt.url, nil) + req.Header.Set("Location", "github.com/gofiber/redirect") + resp, err := app.Test(req) + if err != nil { + t.Fatalf(`%s: %s`, t.Name(), err) + } + if resp.StatusCode != tt.statusCode { + t.Fatalf(`%s: StatusCode: got %v - expected %v`, t.Name(), resp.StatusCode, tt.statusCode) + } + if resp.Header.Get("Location") != tt.redirectTo { + t.Fatalf(`%s: Expecting Location: %s`, t.Name(), tt.redirectTo) + } + }) + } + +} diff --git a/middleware/rewrite/README.md b/middleware/rewrite/README.md new file mode 100644 index 00000000..b06b739b --- /dev/null +++ b/middleware/rewrite/README.md @@ -0,0 +1,48 @@ +# Rewrite + +![Release](https://img.shields.io/github/release/gofiber/rewrite.svg) +[![Discord](https://img.shields.io/badge/discord-join%20channel-7289DA)](https://gofiber.io/discord) +![Test](https://github.com/gofiber/rewrite/workflows/Test/badge.svg) +![Security](https://github.com/gofiber/rewrite/workflows/Security/badge.svg) +![Linter](https://github.com/gofiber/rewrite/workflows/Linter/badge.svg) + +### Install +``` +go get -u github.com/gofiber/fiber/v2 +go get -u github.com/gofiber/rewrite/v2 +``` +### Example +```go +package main + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/rewrite/v2" +) + +func main() { + app := fiber.New() + + app.Use(rewrite.New(rewrite.Config{ + Rules: map[string]string{ + "/old": "/new", + "/old/*": "/new/$1", + }, + })) + + app.Get("/new", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + app.Get("/new/*", func(c *fiber.Ctx) error { + return c.SendString("Wildcard: " + c.Params("*")) + }) + + app.Listen(":3000") +} + +``` +### Test +```curl +curl http://localhost:3000/old +curl http://localhost:3000/old/hello +``` diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go new file mode 100644 index 00000000..276f9b9a --- /dev/null +++ b/middleware/rewrite/rewrite.go @@ -0,0 +1,89 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package rewrite + +import ( + "regexp" + "strconv" + "strings" + + "github.com/gofiber/fiber/v3" +) + +// Config ... +type Config struct { + // Filter defines a function to skip middleware. + // Optional. Default: nil + Filter func(*fiber.Ctx) bool + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Required. Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rules map[string]string + // // Redirect determns if the client should be redirected + // // By default this is disabled and urls are rewritten on the server + // // Optional. Default: false + // Redirect bool + // // The status code when redirecting + // // This is ignored if Redirect is disabled + // // Optional. Default: 302 Temporary Redirect + // StatusCode int + rulesRegex map[*regexp.Regexp]string +} + +// New ... +func New(config ...Config) fiber.Handler { + // Init config + var cfg Config + if len(config) > 0 { + cfg = config[0] + } + // if cfg.StatusCode == 0 { + // cfg.StatusCode = 302 // Temporary Redirect + // } + cfg = config[0] + cfg.rulesRegex = map[*regexp.Regexp]string{} + // Initialize + for k, v := range cfg.Rules { + k = strings.Replace(k, "*", "(.*)", -1) + k = k + "$" + cfg.rulesRegex[regexp.MustCompile(k)] = v + } + // Middleware function + return func(c *fiber.Ctx) error { + // Filter request to skip middleware + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + // Rewrite + for k, v := range cfg.rulesRegex { + replacer := captureTokens(k, c.Path()) + if replacer != nil { + c.Path(replacer.Replace(v)) + break + } + } + return c.Next() + } +} + +// https://github.com/labstack/echo/blob/master/middleware/rewrite.go +func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { + groups := pattern.FindAllStringSubmatch(input, -1) + if groups == nil { + return nil + } + values := groups[0][1:] + replace := make([]string, 2*len(values)) + for i, v := range values { + j := 2 * i + replace[j] = "$" + strconv.Itoa(i+1) + replace[j+1] = v + } + return strings.NewReplacer(replace...) +} diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go new file mode 100644 index 00000000..73704321 --- /dev/null +++ b/middleware/rewrite/rewrite_test.go @@ -0,0 +1,5 @@ +// ๐Ÿš€ Fiber is an Express inspired web framework written in Go with ๐Ÿ’– +// ๐Ÿ“Œ API Documentation: https://fiber.wiki +// ๐Ÿ“ Github Repository: https://github.com/gofiber/fiber + +package rewrite