From 10d6f69a89390397aa5dd23f4f7352ea05638f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Fri, 23 Sep 2022 09:17:34 +0300 Subject: [PATCH] :sparkles: v3 (feature): new redirection methods (#2014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :sparkles: v3 (feature): new redirection methods * add flash messages * withinput, parsing flash message * add tests * add benchmarks * gosec issues * fix tests * fix tests * fix performance issues * fix performance issues * optimization. * better names * fix tests * Update router.go * fix * fix old messaages with flash messages behavior, add new test case with req * complete some reviews * add pool for redirection. * use constant * update * ✨ v3 (feature): new redirection methods * fix tests, optimize cookie parsing (9 allocs -> 1 alloc) * test case for message includes comma * cleanup * optimization. * some improvements for the redirect feature * fix Benchmark_Redirect_Route_WithFlashMessages * some improvements for the redirect feature * Update redirect.go * improve message parsing & test case Co-authored-by: René Werner --- app.go | 5 +- client_test.go | 4 +- ctx.go | 58 +--- ctx_interface.go | 21 +- ctx_test.go | 179 ----------- error.go | 5 + middleware/expvar/expvar.go | 2 +- middleware/pprof/pprof.go | 2 +- middleware/redirect/redirect.go | 2 +- redirect.go | 295 ++++++++++++++++++ redirect_test.go | 528 ++++++++++++++++++++++++++++++++ router.go | 5 + 12 files changed, 861 insertions(+), 245 deletions(-) create mode 100644 redirect.go create mode 100644 redirect_test.go diff --git a/app.go b/app.go index 708b16a1..a4e24d1f 100644 --- a/app.go +++ b/app.go @@ -10,6 +10,8 @@ package fiber import ( "bufio" "bytes" + "encoding/json" + "encoding/xml" "errors" "fmt" "net" @@ -22,9 +24,6 @@ import ( "sync/atomic" "time" - "encoding/json" - "encoding/xml" - "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) diff --git a/client_test.go b/client_test.go index e3d78dca..670d8128 100644 --- a/client_test.go +++ b/client_test.go @@ -1078,9 +1078,9 @@ func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { app.Get("/", func(c Ctx) error { if c.Request().URI().QueryArgs().Has("foo") { - return c.Redirect("/foo") + return c.Redirect().To("/foo") } - return c.Redirect("/") + return c.Redirect().To("/") }) app.Get("/foo", func(c Ctx) error { return c.SendString("redirect") diff --git a/ctx.go b/ctx.go index 5d6c57df..1b7cb846 100644 --- a/ctx.go +++ b/ctx.go @@ -52,6 +52,8 @@ type DefaultCtx struct { matched bool // Non use route matched viewBindMap *dictpool.Dict // Default view map to bind template engine bind *Bind // Default bind reference + redirect *Redirect // Default redirect reference + redirectionMessages []string // Messages of the previous redirect } // TLSHandler object @@ -896,16 +898,17 @@ func (c *DefaultCtx) Range(size int) (rangeData Range, err error) { return } -// Redirect to the URL derived from the specified path, with specified status. +// Redirect returns the Redirect reference. +// Use Redirect().Status() to set custom redirection status code. // If status is not specified, status defaults to 302 Found. -func (c *DefaultCtx) Redirect(location string, status ...int) error { - c.setCanonical(HeaderLocation, location) - if len(status) > 0 { - c.Status(status[0]) - } else { - c.Status(StatusFound) +// You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection. +func (c *DefaultCtx) Redirect() *Redirect { + if c.redirect == nil { + c.redirect = AcquireRedirect() + c.redirect.c = c } - return nil + + return c.redirect } // Add vars to default view var map binding to template engine. @@ -956,45 +959,6 @@ func (c *DefaultCtx) GetRouteURL(routeName string, params Map) (string, error) { return c.getLocationFromRoute(c.App().GetRoute(routeName), params) } -// RedirectToRoute to the Route registered in the app with appropriate parameters -// If status is not specified, status defaults to 302 Found. -// If you want to send queries to route, you must add "queries" key typed as map[string]string to params. -func (c *DefaultCtx) RedirectToRoute(routeName string, params Map, status ...int) error { - location, err := c.getLocationFromRoute(c.App().GetRoute(routeName), params) - if err != nil { - return err - } - - // Check queries - if queries, ok := params["queries"].(map[string]string); ok { - queryText := bytebufferpool.Get() - defer bytebufferpool.Put(queryText) - - i := 1 - for k, v := range queries { - _, _ = queryText.WriteString(k + "=" + v) - - if i != len(queries) { - _, _ = queryText.WriteString("&") - } - i++ - } - - return c.Redirect(location+"?"+queryText.String(), status...) - } - return c.Redirect(location, status...) -} - -// RedirectBack to the URL to referer -// If status is not specified, status defaults to 302 Found. -func (c *DefaultCtx) RedirectBack(fallback string, status ...int) error { - location := c.Get(HeaderReferer) - if location == "" { - location = fallback - } - return c.Redirect(location, status...) -} - // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template func (c *DefaultCtx) Render(name string, bind Map, layouts ...string) error { diff --git a/ctx_interface.go b/ctx_interface.go index 3adcea49..a050356d 100644 --- a/ctx_interface.go +++ b/ctx_interface.go @@ -229,9 +229,11 @@ type Ctx interface { // Range returns a struct containing the type and a slice of ranges. Range(size int) (rangeData Range, err error) - // Redirect to the URL derived from the specified path, with specified status. + // Redirect returns the Redirect reference. + // Use Redirect().Status() to set custom redirection status code. // If status is not specified, status defaults to 302 Found. - Redirect(location string, status ...int) error + // You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection. + Redirect() *Redirect // Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. @@ -240,15 +242,6 @@ type Ctx interface { // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) - // RedirectToRoute to the Route registered in the app with appropriate parameters - // If status is not specified, status defaults to 302 Found. - // If you want to send queries to route, you must add "queries" key typed as map[string]string to params. - RedirectToRoute(routeName string, params Map, status ...int) error - - // RedirectBack to the URL to referer - // If status is not specified, status defaults to 302 Found. - RedirectBack(fallback string, status ...int) error - // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template Render(name string, bind Map, layouts ...string) error @@ -445,8 +438,14 @@ func (c *DefaultCtx) release() { c.route = nil c.fasthttp = nil c.bind = nil + c.redirectionMessages = c.redirectionMessages[:0] if c.viewBindMap != nil { dictpool.ReleaseDict(c.viewBindMap) + c.viewBindMap = nil + } + if c.redirect != nil { + ReleaseRedirect(c.redirect) + c.redirect = nil } } diff --git a/ctx_test.go b/ctx_test.go index 82e21d39..f0552a11 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -16,7 +16,6 @@ import ( "io" "mime/multipart" "net/http/httptest" - "net/url" "os" "path/filepath" "strconv" @@ -2254,137 +2253,6 @@ func Test_Ctx_Next_Error(t *testing.T) { require.Equal(t, "Works", resp.Header.Get("X-Next-Result")) } -// go test -run Test_Ctx_Redirect -func Test_Ctx_Redirect(t *testing.T) { - t.Parallel() - app := New() - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.Redirect("http://default.com") - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "http://default.com", string(c.Response().Header.Peek(HeaderLocation))) - - c.Redirect("http://example.com", 301) - require.Equal(t, 301, c.Response().StatusCode()) - require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectToRouteWithParams -func Test_Ctx_RedirectToRouteWithParams(t *testing.T) { - t.Parallel() - app := New() - app.Get("/user/:name", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectToRoute("user", Map{ - "name": "fiber", - }) - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectToRouteWithParams -func Test_Ctx_RedirectToRouteWithQueries(t *testing.T) { - t.Parallel() - app := New() - app.Get("/user/:name", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectToRoute("user", Map{ - "name": "fiber", - "queries": map[string]string{"data[0][name]": "john", "data[0][age]": "10", "test": "doe"}, - }) - require.Equal(t, 302, c.Response().StatusCode()) - // analysis of query parameters with url parsing, since a map pass is always randomly ordered - location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) - require.NoError(t, err, "url.Parse(location)") - require.Equal(t, "/user/fiber", location.Path) - require.Equal(t, url.Values{"data[0][name]": []string{"john"}, "data[0][age]": []string{"10"}, "test": []string{"doe"}}, location.Query()) -} - -// go test -run Test_Ctx_RedirectToRouteWithOptionalParams -func Test_Ctx_RedirectToRouteWithOptionalParams(t *testing.T) { - t.Parallel() - app := New() - app.Get("/user/:name?", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectToRoute("user", Map{ - "name": "fiber", - }) - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectToRouteWithOptionalParamsWithoutValue -func Test_Ctx_RedirectToRouteWithOptionalParamsWithoutValue(t *testing.T) { - t.Parallel() - app := New() - app.Get("/user/:name?", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectToRoute("user", Map{}) - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/user/", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectToRouteWithGreedyParameters -func Test_Ctx_RedirectToRouteWithGreedyParameters(t *testing.T) { - t.Parallel() - app := New() - app.Get("/user/+", func(c Ctx) error { - return c.JSON(c.Params("+")) - }).Name("user") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectToRoute("user", Map{ - "+": "test/routes", - }) - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/user/test/routes", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectBack -func Test_Ctx_RedirectBack(t *testing.T) { - t.Parallel() - app := New() - app.Get("/", func(c Ctx) error { - return c.JSON("Home") - }).Name("home") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.RedirectBack("/") - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation))) -} - -// go test -run Test_Ctx_RedirectBackWithReferer -func Test_Ctx_RedirectBackWithReferer(t *testing.T) { - t.Parallel() - app := New() - app.Get("/", func(c Ctx) error { - return c.JSON("Home") - }).Name("home") - app.Get("/back", func(c Ctx) error { - return c.JSON("Back") - }).Name("back") - c := app.NewCtx(&fasthttp.RequestCtx{}) - - c.Request().Header.Set(HeaderReferer, "/back") - c.RedirectBack("/") - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/back", c.Get(HeaderReferer)) - require.Equal(t, "/back", string(c.Response().Header.Peek(HeaderLocation))) -} - // go test -run Test_Ctx_Render func Test_Ctx_Render(t *testing.T) { t.Parallel() @@ -2597,53 +2465,6 @@ func Benchmark_Ctx_RenderWithLocalsAndBindVars(b *testing.B) { require.Equal(b, "

Hello, World! Test

", string(c.Response().Body())) } -func Benchmark_Ctx_RedirectToRoute(b *testing.B) { - app := New() - app.Get("/user/:name", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - - c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) - - b.ReportAllocs() - b.ResetTimer() - - for n := 0; n < b.N; n++ { - c.RedirectToRoute("user", Map{ - "name": "fiber", - }) - } - - require.Equal(b, 302, c.Response().StatusCode()) - require.Equal(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) -} - -func Benchmark_Ctx_RedirectToRouteWithQueries(b *testing.B) { - app := New() - app.Get("/user/:name", func(c Ctx) error { - return c.JSON(c.Params("name")) - }).Name("user") - - c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) - - b.ReportAllocs() - b.ResetTimer() - - for n := 0; n < b.N; n++ { - c.RedirectToRoute("user", Map{ - "name": "fiber", - "queries": map[string]string{"a": "a", "b": "b"}, - }) - } - - require.Equal(b, 302, c.Response().StatusCode()) - // analysis of query parameters with url parsing, since a map pass is always randomly ordered - location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) - require.NoError(b, err, "url.Parse(location)") - require.Equal(b, "/user/fiber", location.Path) - require.Equal(b, url.Values{"a": []string{"a"}, "b": []string{"b"}}, location.Query()) -} - func Benchmark_Ctx_RenderLocals(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() diff --git a/error.go b/error.go index 6b08acef..a688411b 100644 --- a/error.go +++ b/error.go @@ -12,6 +12,11 @@ var ( ErrGracefulTimeout = stdErrors.New("shutdown: graceful timeout has been reached, exiting") ) +// Fiber redirection errors +var ( + ErrRedirectBackNoFallback = NewError(StatusInternalServerError, "Referer not found, you have to enter fallback URL for redirection.") +) + // Range errors var ( ErrRangeMalformed = stdErrors.New("range: malformed range header string") diff --git a/middleware/expvar/expvar.go b/middleware/expvar/expvar.go index 3b4861ec..8ea09e1a 100644 --- a/middleware/expvar/expvar.go +++ b/middleware/expvar/expvar.go @@ -29,6 +29,6 @@ func New(config ...Config) fiber.Handler { return nil } - return c.Redirect("/debug/vars", 302) + return c.Redirect().To("/debug/vars") } } diff --git a/middleware/pprof/pprof.go b/middleware/pprof/pprof.go index e3a5ea10..afaab440 100644 --- a/middleware/pprof/pprof.go +++ b/middleware/pprof/pprof.go @@ -72,7 +72,7 @@ func New(config ...Config) fiber.Handler { path = "/debug/pprof/" } - return c.Redirect(path, fiber.StatusFound) + return c.Redirect().To(path) } return nil } diff --git a/middleware/redirect/redirect.go b/middleware/redirect/redirect.go index 8473df3a..0dec5263 100644 --- a/middleware/redirect/redirect.go +++ b/middleware/redirect/redirect.go @@ -61,7 +61,7 @@ func New(config ...Config) fiber.Handler { for k, v := range cfg.rulesRegex { replacer := captureTokens(k, c.Path()) if replacer != nil { - return c.Redirect(replacer.Replace(v), cfg.StatusCode) + return c.Redirect().Status(cfg.StatusCode).To(replacer.Replace(v)) } } return c.Next() diff --git a/redirect.go b/redirect.go new file mode 100644 index 00000000..25f01fde --- /dev/null +++ b/redirect.go @@ -0,0 +1,295 @@ +// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ +// 📝 Github Repository: https://github.com/gofiber/fiber +// 📌 API Documentation: https://docs.gofiber.io + +package fiber + +import ( + "strings" + "sync" + + "github.com/gofiber/fiber/v3/binder" + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/bytebufferpool" +) + +var ( + // Pool for redirection + redirectPool = sync.Pool{ + New: func() any { + return &Redirect{ + status: StatusFound, + oldInput: make(map[string]string, 0), + } + }, + } +) + +// Cookie name to send flash messages when to use redirection. +const ( + FlashCookieName = "fiber_flash" + OldInputDataPrefix = "old_input_data_" + CookieDataSeparator = "," + CookieDataAssigner = ":" +) + +// Redirect is a struct to use it with Ctx. +type Redirect struct { + c *DefaultCtx // Embed ctx + status int // Status code of redirection. Default: StatusFound + + messages []string // Flash messages + oldInput map[string]string // Old input data +} + +// A config to use with Redirect().Route() +// You can specify queries or route parameters. +// NOTE: We don't use net/url to parse parameters because of it has poor performance. You have to pass map. +type RedirectConfig struct { + Params Map // Route parameters + Queries map[string]string // Query map +} + +// AcquireRedirect return default Redirect reference from the redirect pool +func AcquireRedirect() *Redirect { + return redirectPool.Get().(*Redirect) +} + +// ReleaseRedirect returns c acquired via Redirect to redirect pool. +// +// It is forbidden accessing req and/or its' members after returning +// it to redirect pool. +func ReleaseRedirect(r *Redirect) { + r.release() + redirectPool.Put(r) +} + +func (r *Redirect) release() { + r.status = 302 + r.messages = r.messages[:0] + // reset map + for k := range r.oldInput { + delete(r.oldInput, k) + } + r.c = nil +} + +// Status sets the status code of redirection. +// If status is not specified, status defaults to 302 Found. +func (r *Redirect) Status(code int) *Redirect { + r.status = code + + return r +} + +// You can send flash messages by using With(). +// They will be sent as a cookie. +// You can get them by using: Redirect().Messages(), Redirect().Message() +// Note: You must use escape char before using ',' and ':' chars to avoid wrong parsing. +func (r *Redirect) With(key string, value string) *Redirect { + r.messages = append(r.messages, key+CookieDataAssigner+value) + + return r +} + +// You can send input data by using WithInput(). +// They will be sent as a cookie. +// This method can send form, multipart form, query data to redirected route. +// You can get them by using: Redirect().OldInputs(), Redirect().OldInput() +func (r *Redirect) WithInput() *Redirect { + // Get content-type + ctype := utils.ToLower(utils.UnsafeString(r.c.Context().Request.Header.ContentType())) + ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype)) + + switch ctype { + case MIMEApplicationForm: + _ = r.c.Bind().Form(r.oldInput) + case MIMEMultipartForm: + _ = r.c.Bind().MultipartForm(r.oldInput) + default: + _ = r.c.Bind().Query(r.oldInput) + } + + return r +} + +// Get flash messages. +func (r *Redirect) Messages() map[string]string { + msgs := r.c.redirectionMessages + flashMessages := make(map[string]string, len(msgs)) + + for _, msg := range msgs { + k, v := parseMessage(msg) + + if !strings.HasPrefix(k, OldInputDataPrefix) { + flashMessages[k] = v + } + } + + return flashMessages +} + +// Get flash message by key. +func (r *Redirect) Message(key string) string { + msgs := r.c.redirectionMessages + + for _, msg := range msgs { + k, v := parseMessage(msg) + + if !strings.HasPrefix(k, OldInputDataPrefix) && k == key { + return v + } + } + return "" +} + +// Get old input data. +func (r *Redirect) OldInputs() map[string]string { + msgs := r.c.redirectionMessages + oldInputs := make(map[string]string, len(msgs)) + + for _, msg := range msgs { + k, v := parseMessage(msg) + + if strings.HasPrefix(k, OldInputDataPrefix) { + // remove "old_input_data_" part from key + oldInputs[k[len(OldInputDataPrefix):]] = v + } + } + return oldInputs +} + +// Get old input data by key. +func (r *Redirect) OldInput(key string) string { + msgs := r.c.redirectionMessages + + for _, msg := range msgs { + k, v := parseMessage(msg) + + if strings.HasPrefix(k, OldInputDataPrefix) && k[len(OldInputDataPrefix):] == key { + return v + } + } + return "" + +} + +// Redirect to the URL derived from the specified path, with specified status. +func (r *Redirect) To(location string) error { + r.c.setCanonical(HeaderLocation, location) + r.c.Status(r.status) + + return nil +} + +// Route redirects to the Route registered in the app with appropriate parameters. +// If you want to send queries or params to route, you should use config parameter. +func (r *Redirect) Route(name string, config ...RedirectConfig) error { + // Check config + cfg := RedirectConfig{} + if len(config) > 0 { + cfg = config[0] + } + + // Get location from route name + location, err := r.c.getLocationFromRoute(r.c.App().GetRoute(name), cfg.Params) + if err != nil { + return err + } + + // Flash messages + if len(r.messages) > 0 || len(r.oldInput) > 0 { + messageText := bytebufferpool.Get() + defer bytebufferpool.Put(messageText) + + // flash messages + for i, message := range r.messages { + _, _ = messageText.WriteString(message) + // when there are more messages or oldInput -> add a comma + if len(r.messages)-1 != i || (len(r.messages)-1 == i && len(r.oldInput) > 0) { + _, _ = messageText.WriteString(CookieDataSeparator) + } + } + r.messages = r.messages[:0] + + // old input data + i := 1 + for k, v := range r.oldInput { + _, _ = messageText.WriteString(OldInputDataPrefix + k + CookieDataAssigner + v) + if len(r.oldInput) != i { + _, _ = messageText.WriteString(CookieDataSeparator) + } + i++ + } + + r.c.Cookie(&Cookie{ + Name: FlashCookieName, + Value: r.c.app.getString(messageText.Bytes()), + SessionOnly: true, + }) + } + + // Check queries + if len(cfg.Queries) > 0 { + queryText := bytebufferpool.Get() + defer bytebufferpool.Put(queryText) + + i := 1 + for k, v := range cfg.Queries { + _, _ = queryText.WriteString(k + "=" + v) + + if i != len(cfg.Queries) { + _, _ = queryText.WriteString("&") + } + i++ + } + + return r.To(location + "?" + r.c.app.getString(queryText.Bytes())) + } + + return r.To(location) +} + +// Redirect back to the URL to referer. +func (r *Redirect) Back(fallback ...string) error { + location := r.c.Get(HeaderReferer) + if location == "" { + // Check fallback URL + if len(fallback) == 0 { + err := ErrRedirectBackNoFallback + r.c.Status(err.Code) + + return err + } + location = fallback[0] + } + + return r.To(location) +} + +// setFlash is a method to get flash messages before removing them +func (r *Redirect) setFlash() { + // parse flash messages + cookieValue := r.c.Cookies(FlashCookieName) + + var commaPos int + for { + commaPos = findNextNonEscapedCharsetPosition(cookieValue, []byte(CookieDataSeparator)) + if commaPos != -1 { + r.c.redirectionMessages = append(r.c.redirectionMessages, strings.Trim(cookieValue[:commaPos], " ")) + cookieValue = cookieValue[commaPos+1:] + } else { + r.c.redirectionMessages = append(r.c.redirectionMessages, strings.Trim(cookieValue, " ")) + break + } + } + + r.c.ClearCookie(FlashCookieName) +} + +func parseMessage(raw string) (key, value string) { + if i := findNextNonEscapedCharsetPosition(raw, []byte(CookieDataAssigner)); i != -1 { + return RemoveEscapeChar(raw[:i]), RemoveEscapeChar(raw[i+1:]) + } + return RemoveEscapeChar(raw), "" +} diff --git a/redirect_test.go b/redirect_test.go new file mode 100644 index 00000000..f603c6fc --- /dev/null +++ b/redirect_test.go @@ -0,0 +1,528 @@ +// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ +// 📝 Github Repository: https://github.com/gofiber/fiber +// 📌 API Documentation: https://docs.gofiber.io + +package fiber + +import ( + "context" + "net" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" +) + +// go test -run Test_Redirect_To +func Test_Redirect_To(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().To("http://default.com") + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "http://default.com", string(c.Response().Header.Peek(HeaderLocation))) + + c.Redirect().Status(301).To("http://example.com") + require.Equal(t, 301, c.Response().StatusCode()) + require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Route_WithParams +func Test_Redirect_Route_WithParams(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "name": "fiber", + }, + }) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Route_WithParams_WithQueries +func Test_Redirect_Route_WithParams_WithQueries(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "name": "fiber", + }, + Queries: map[string]string{"data[0][name]": "john", "data[0][age]": "10", "test": "doe"}, + }) + require.Equal(t, 302, c.Response().StatusCode()) + // analysis of query parameters with url parsing, since a map pass is always randomly ordered + location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) + require.NoError(t, err, "url.Parse(location)") + require.Equal(t, "/user/fiber", location.Path) + require.Equal(t, url.Values{"data[0][name]": []string{"john"}, "data[0][age]": []string{"10"}, "test": []string{"doe"}}, location.Query()) +} + +// go test -run Test_Redirect_Route_WithOptionalParams +func Test_Redirect_Route_WithOptionalParams(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name?", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "name": "fiber", + }, + }) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Route_WithOptionalParamsWithoutValue +func Test_Redirect_Route_WithOptionalParamsWithoutValue(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name?", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Route("user") + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user/", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Route_WithGreedyParameters +func Test_Redirect_Route_WithGreedyParameters(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/+", func(c Ctx) error { + return c.JSON(c.Params("+")) + }).Name("user") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "+": "test/routes", + }, + }) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user/test/routes", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Back +func Test_Redirect_Back(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", func(c Ctx) error { + return c.JSON("Home") + }).Name("home") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Redirect().Back("/") + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation))) + + err := c.Redirect().Back() + require.Equal(t, 500, c.Response().StatusCode()) + require.ErrorAs(t, ErrRedirectBackNoFallback, &err) +} + +// go test -run Test_Redirect_Back_WithReferer +func Test_Redirect_Back_WithReferer(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", func(c Ctx) error { + return c.JSON("Home") + }).Name("home") + app.Get("/back", func(c Ctx) error { + return c.JSON("Back") + }).Name("back") + c := app.NewCtx(&fasthttp.RequestCtx{}) + + c.Request().Header.Set(HeaderReferer, "/back") + c.Redirect().Back("/") + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/back", c.Get(HeaderReferer)) + require.Equal(t, "/back", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Redirect_Route_WithFlashMessages +func Test_Redirect_Route_WithFlashMessages(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Redirect().With("success", "1").With("message", "test").Route("user") + + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" + require.True(t, equal) + + c.Redirect().setFlash() + require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) +} + +// go test -run Test_Redirect_Route_WithOldInput +func Test_Redirect_Route_WithOldInput(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().URI().SetQueryString("id=1&name=tom") + c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user") + + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + require.Contains(t, c.GetRespHeader(HeaderSetCookie), "fiber_flash=") + require.Contains(t, c.GetRespHeader(HeaderSetCookie), "success:1") + require.Contains(t, c.GetRespHeader(HeaderSetCookie), "message:test") + + require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_id:1") + require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_name:tom") + + c.Redirect().setFlash() + require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) +} + +// go test -run Test_Redirect_setFlash +func Test_Redirect_setFlash(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + + c.Redirect().setFlash() + + require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + + require.Equal(t, "1", c.Redirect().Message("success")) + require.Equal(t, "test", c.Redirect().Message("message")) + require.Equal(t, map[string]string{"success": "1", "message": "test"}, c.Redirect().Messages()) + + require.Equal(t, "1", c.Redirect().OldInput("id")) + require.Equal(t, "tom", c.Redirect().OldInput("name")) + require.Equal(t, map[string]string{"id": "1", "name": "tom"}, c.Redirect().OldInputs()) +} + +// go test -run Test_Redirect_Request +func Test_Redirect_Request(t *testing.T) { + t.Parallel() + + app := New() + + app.Get("/", func(c Ctx) error { + return c.Redirect().With("key", "value").With("key2", "value2").With("co\\:m\\,ma", "Fi\\:ber\\, v3").Route("name") + }) + + app.Get("/with-inputs", func(c Ctx) error { + return c.Redirect().WithInput().With("key", "value").With("key2", "value2").Route("name") + }) + + app.Get("/just-inputs", func(c Ctx) error { + return c.Redirect().WithInput().Route("name") + }) + + app.Get("/redirected", func(c Ctx) error { + return c.JSON(Map{ + "messages": c.Redirect().Messages(), + "inputs": c.Redirect().OldInputs(), + }) + }).Name("name") + + // Start test server + ln := fasthttputil.NewInmemoryListener() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + err := app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + GracefulContext: ctx, + }) + + require.NoError(t, err) + }() + + // Test cases + testCases := []struct { + URL string + CookieValue string + ExpectedBody string + ExpectedStatusCode int + ExceptedErrsLen int + }{ + { + URL: "/", + CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", + ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, + ExpectedStatusCode: StatusOK, + ExceptedErrsLen: 0, + }, + { + URL: "/with-inputs?name=john&surname=doe", + CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", + ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, + ExpectedStatusCode: StatusOK, + ExceptedErrsLen: 0, + }, + { + URL: "/just-inputs?name=john&surname=doe", + CookieValue: "old_input_data_name:john,old_input_data_surname:doe", + ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, + ExpectedStatusCode: StatusOK, + ExceptedErrsLen: 0, + }, + } + + for _, tc := range testCases { + a := Get("http://example.com" + tc.URL) + a.Cookie(FlashCookieName, tc.CookieValue) + a.MaxRedirectsCount(1) + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + code, body, errs := a.String() + + require.Equal(t, tc.ExpectedStatusCode, code) + require.Equal(t, tc.ExpectedBody, body) + require.Equal(t, tc.ExceptedErrsLen, len(errs)) + } +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4 +func Benchmark_Redirect_Route(b *testing.B) { + app := New() + app.Get("/user/:name", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "name": "fiber", + }, + }) + } + + require.Equal(b, 302, c.Response().StatusCode()) + require.Equal(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithQueries -benchmem -count=4 +func Benchmark_Redirect_Route_WithQueries(b *testing.B) { + app := New() + app.Get("/user/:name", func(c Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + c.Redirect().Route("user", RedirectConfig{ + Params: Map{ + "name": "fiber", + }, + Queries: map[string]string{"a": "a", "b": "b"}, + }) + } + + require.Equal(b, 302, c.Response().StatusCode()) + // analysis of query parameters with url parsing, since a map pass is always randomly ordered + location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) + require.NoError(b, err, "url.Parse(location)") + require.Equal(b, "/user/fiber", location.Path) + require.Equal(b, url.Values{"a": []string{"a"}, "b": []string{"b"}}, location.Query()) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithFlashMessages -benchmem -count=4 +func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + c.Redirect().With("success", "1").With("message", "test").Route("user") + } + + require.Equal(b, 302, c.Response().StatusCode()) + require.Equal(b, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" + require.True(b, equal) + + c.Redirect().setFlash() + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_setFlash -benchmem -count=4 +func Benchmark_Redirect_setFlash(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + c.Redirect().setFlash() + } + + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + + require.Equal(b, "1", c.Redirect().Message("success")) + require.Equal(b, "test", c.Redirect().Message("message")) + require.Equal(b, map[string]string{"success": "1", "message": "test"}, c.Redirect().Messages()) + + require.Equal(b, "1", c.Redirect().OldInput("id")) + require.Equal(b, "tom", c.Redirect().OldInput("name")) + require.Equal(b, map[string]string{"id": "1", "name": "tom"}, c.Redirect().OldInputs()) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_Messages -benchmem -count=4 +func Benchmark_Redirect_Messages(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + c.Redirect().setFlash() + + var msgs map[string]string + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + msgs = c.Redirect().Messages() + } + + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(b, map[string]string{"success": "1", "message": "test"}, msgs) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_OldInputs -benchmem -count=4 +func Benchmark_Redirect_OldInputs(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + c.Redirect().setFlash() + + var oldInputs map[string]string + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + oldInputs = c.Redirect().OldInputs() + } + + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(b, map[string]string{"id": "1", "name": "tom"}, oldInputs) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_Message -benchmem -count=4 +func Benchmark_Redirect_Message(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + c.Redirect().setFlash() + + var msg string + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + msg = c.Redirect().Message("message") + } + + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(b, "test", msg) +} + +// go test -v -run=^$ -bench=Benchmark_Redirect_OldInput -benchmem -count=4 +func Benchmark_Redirect_OldInput(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + c.Redirect().setFlash() + + var input string + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + input = c.Redirect().OldInput("name") + } + + require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(b, "tom", input) +} diff --git a/router.go b/router.go index 614ff937..1d7851a6 100644 --- a/router.go +++ b/router.go @@ -167,6 +167,11 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) { return } + // check flash messages + if strings.Contains(utils.UnsafeString(c.Request().Header.RawHeaders()), FlashCookieName) { + c.Redirect().setFlash() + } + // Find match in stack _, err := app.next(c, app.newCtxFunc != nil) if err != nil {