diff --git a/redirect.go b/redirect.go index 3ad9cfb3..053b11dc 100644 --- a/redirect.go +++ b/redirect.go @@ -5,6 +5,7 @@ package fiber import ( + "encoding/hex" "errors" "sync" @@ -296,9 +297,12 @@ func (r *Redirect) Back(fallback ...string) error { // parseAndClearFlashMessages is a method to get flash messages before they are getting removed func (r *Redirect) parseAndClearFlashMessages() { // parse flash messages - cookieValue := r.c.Cookies(FlashCookieName) + cookieValue, err := hex.DecodeString(r.c.Cookies(FlashCookieName)) + if err != nil { + return + } - _, err := r.c.flashMessages.UnmarshalMsg(r.c.app.getBytes(cookieValue)) + _, err = r.c.flashMessages.UnmarshalMsg(cookieValue) if err != nil { return } @@ -316,9 +320,12 @@ func (r *Redirect) processFlashMessages() { return } + dst := make([]byte, hex.EncodedLen(len(val))) + hex.Encode(dst, val) + r.c.Cookie(&Cookie{ Name: FlashCookieName, - Value: r.c.app.getString(val), + Value: r.c.app.getString(dst), SessionOnly: true, }) } diff --git a/redirect_test.go b/redirect_test.go index 87c3d80a..c4d50a55 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -6,7 +6,12 @@ package fiber import ( "bytes" + "encoding/hex" + "encoding/json" + "io" "mime/multipart" + "net/http" + "net/http/httptest" "net/url" "testing" @@ -46,7 +51,9 @@ func Test_Redirect_To_WithFlashMessages(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 2) @@ -189,7 +196,9 @@ func Test_Redirect_Back_WithFlashMessages(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 2) @@ -240,7 +249,9 @@ func Test_Redirect_Route_WithFlashMessages(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 2) @@ -277,7 +288,9 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 4) @@ -313,7 +326,9 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 4) @@ -357,7 +372,9 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(t, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 4) @@ -403,7 +420,7 @@ func Test_Redirect_parseAndClearFlashMessages(t *testing.T) { val, err := msgs.MarshalMsg(nil) require.NoError(t, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() @@ -456,6 +473,166 @@ func Test_Redirect_parseAndClearFlashMessages(t *testing.T) { Value: "1", }, }, c.Redirect().OldInputs()) + + c.Request().Header.Set(HeaderCookie, "fiber_flash=test") + + c.Redirect().parseAndClearFlashMessages() + + require.Empty(t, c.Redirect().messages) +} + +// Test_Redirect_parseAndClearFlashMessages_InvalidHex tests the case where hex decoding fails +func Test_Redirect_parseAndClearFlashMessages_InvalidHex(t *testing.T) { + t.Parallel() + + app := New() + + // Setup request and response + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed + defer app.ReleaseCtx(c) + + // Create redirect instance + r := AcquireRedirect() + r.c = c + + // Set invalid hex value in flash cookie + c.Request().Header.SetCookie(FlashCookieName, "not-a-valid-hex-string") + + // Call parseAndClearFlashMessages + r.parseAndClearFlashMessages() + + // Verify that no flash messages are processed (should be empty) + require.Empty(t, r.messages) + + // Release redirect + ReleaseRedirect(r) +} + +func Test_Redirect_CompleteFlowWithFlashMessages(t *testing.T) { + t.Parallel() + + app := New() + + // First handler that sets flash messages and redirects + app.Get("/source", func(c Ctx) error { + // Redirect to the target handler + return c.Redirect().With("string_message", "Hello, World!"). + With("number_message", "12345"). + With("bool_message", "true"). + To("/target") + }) + + // Second handler that receives and processes flash messages + app.Get("/target", func(c Ctx) error { + // Get all flash messages and return them as a JSON response + return c.JSON(Map{ + "string_message": c.Redirect().Message("string_message").Value, + "number_message": c.Redirect().Message("number_message").Value, + "bool_message": c.Redirect().Message("bool_message").Value, + }) + }) + + // Step 1: Make the initial request to the source route + req := httptest.NewRequest(MethodGet, "/source", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, StatusSeeOther, resp.StatusCode) + require.Equal(t, "/target", resp.Header.Get(HeaderLocation)) + + // Verify and get the cookie from the response + cookies := resp.Cookies() + var flashCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "fiber_flash" { + flashCookie = cookie + break + } + } + require.NotNil(t, flashCookie, "Flash cookie should be set") + + // Step 2: Make the second request to the target route with the cookie + req = httptest.NewRequest(MethodGet, "/target", nil) + req.Header.Set("Cookie", flashCookie.Name+"="+flashCookie.Value) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) + + // Parse the JSON response and verify flash messages + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(body, &result) + require.NoError(t, err) + + // Verify all flash messages were received correctly + require.Equal(t, "Hello, World!", result["string_message"]) + require.Equal(t, "12345", result["number_message"]) // JSON numbers are float64 + require.Equal(t, "true", result["bool_message"]) +} + +func Test_Redirect_FlashMessagesWithSpecialChars(t *testing.T) { + t.Parallel() + + app := New() + + // Handler that sets flash messages with special characters and redirects + app.Get("/special-source", func(c Ctx) error { + // Create a large message to test encoding of larger data + return c.Redirect().With("null_bytes", "Contains\x00null\x00bytes"). + With("control_chars", "Contains\r\ncontrol\tcharacters"). + With("unicode", "Unicode: δ½ ε₯½δΈ–η•Œ"). + With("emoji", "Emoji: πŸ”₯πŸš€πŸ˜Š"). + To("/special-target") + }) + + // Target handler that receives the flash messages + app.Get("/special-target", func(c Ctx) error { + return c.JSON(Map{ + "null_bytes": c.Redirect().Message("null_bytes").Value, + "control_chars": c.Redirect().Message("control_chars").Value, + "unicode": c.Redirect().Message("unicode").Value, + "emoji": c.Redirect().Message("emoji").Value, + }) + }) + + // Step 1: Make the initial request + req := httptest.NewRequest(MethodGet, "/special-source", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, StatusSeeOther, resp.StatusCode) + require.Equal(t, "/special-target", resp.Header.Get(HeaderLocation)) + + // Get the flash cookie + var flashCookie *http.Cookie + for _, cookie := range resp.Cookies() { + if cookie.Name == "fiber_flash" { + flashCookie = cookie + break + } + } + require.NotNil(t, flashCookie, "Flash cookie should be set") + + // Step 2: Make the second request with the cookie + req = httptest.NewRequest(MethodGet, "/special-target", nil) + req.Header.Set("Cookie", flashCookie.Name+"="+flashCookie.Value) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, StatusOK, resp.StatusCode) + + // Parse and verify the response + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(body, &result) + require.NoError(t, err) + + // Verify special character handling + require.Equal(t, "Contains\x00null\x00bytes", result["null_bytes"]) + require.Equal(t, "Contains\r\ncontrol\tcharacters", result["control_chars"]) + require.Equal(t, "Unicode: δ½ ε₯½δΈ–η•Œ", result["unicode"]) + require.Equal(t, "Emoji: πŸ”₯πŸš€πŸ˜Š", result["emoji"]) } // go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4 @@ -542,7 +719,9 @@ func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(b, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(b, err) require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) @@ -582,7 +761,7 @@ func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) { val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) b.ReportAllocs() b.ResetTimer() @@ -633,7 +812,9 @@ func Benchmark_Redirect_processFlashMessages(b *testing.B) { c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs - _, err := msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) + require.NoError(b, err) + _, err = msgs.UnmarshalMsg(decoded) require.NoError(b, err) require.Len(b, msgs, 2) @@ -653,7 +834,7 @@ func Benchmark_Redirect_Messages(b *testing.B) { val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var msgs []FlashMessage @@ -690,7 +871,7 @@ func Benchmark_Redirect_OldInputs(b *testing.B) { val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var oldInputs []OldInputData @@ -725,7 +906,7 @@ func Benchmark_Redirect_Message(b *testing.B) { val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var msg FlashMessage @@ -756,7 +937,7 @@ func Benchmark_Redirect_OldInput(b *testing.B) { val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) - c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) + c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var input OldInputData