diff --git a/Makefile b/Makefile index 905f4f36..4a183da3 100644 --- a/Makefile +++ b/Makefile @@ -57,11 +57,9 @@ tidy: betteralign: go run github.com/dkorunic/betteralign/cmd/betteralign@latest -test_files -generated_files -apply ./... -## tidy: ⚡️ Generate msgp -.PHONY: msgp -msgp: - go run github.com/tinylib/msgp@latest -file="middleware/cache/manager.go" -o="middleware/cache/manager_msgp.go" -tests=true -unexported - go run github.com/tinylib/msgp@latest -file="middleware/session/data.go" -o="middleware/session/data_msgp.go" -tests=true -unexported - go run github.com/tinylib/msgp@latest -file="middleware/csrf/storage_manager.go" -o="middleware/csrf/storage_manager_msgp.go" -tests=true -unexported - go run github.com/tinylib/msgp@latest -file="middleware/limiter/manager.go" -o="middleware/limiter/manager_msgp.go" -tests=true -unexported - go run github.com/tinylib/msgp@latest -file="middleware/idempotency/response.go" -o="middleware/idempotency/response_msgp.go" -tests=true -unexported +## tidy: ⚡️ Generate msgp && interface implementations +.PHONY: generate +generate: + go install github.com/tinylib/msgp@latest + go install github.com/vburenin/ifacemaker@975a95966976eeb2d4365a7fb236e274c54da64c + go generate ./... diff --git a/ctx.go b/ctx.go index 84378d76..4d7417ee 100644 --- a/ctx.go +++ b/ctx.go @@ -63,7 +63,7 @@ type DefaultCtx struct { pathOriginal string // Original HTTP path pathBuffer []byte // HTTP path buffer detectionPathBuffer []byte // HTTP detectionPath buffer - redirectionMessages []string // Messages of the previous redirect + flashMessages redirectionMsgs // Flash messages indexRoute int // Index of the current route indexHandler int // Index of the current handler methodINT int // HTTP method INT equivalent @@ -1896,7 +1896,7 @@ func (c *DefaultCtx) release() { c.route = nil c.fasthttp = nil c.bind = nil - c.redirectionMessages = c.redirectionMessages[:0] + c.flashMessages = c.flashMessages[:0] c.viewBindMap = sync.Map{} if c.redirect != nil { ReleaseRedirect(c.redirect) diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index bffbe79d..7709f7c9 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -330,6 +330,7 @@ type Ctx interface { Reset(fctx *fasthttp.RequestCtx) // Release is a method to reset context fields when to use ReleaseCtx() release() + getBody() []byte // Methods to use with next stack. getMethodINT() int getIndexRoute() int diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index dd83b7a7..3a796c77 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -10,7 +10,7 @@ import ( // msgp -file="manager.go" -o="manager_msgp.go" -tests=true -unexported // -//go:generate msgp +//go:generate msgp -o=manager_msgp.go -tests=true -unexported type item struct { headers map[string][]byte body []byte diff --git a/middleware/csrf/storage_manager.go b/middleware/csrf/storage_manager.go index e572a9c7..4d3c2642 100644 --- a/middleware/csrf/storage_manager.go +++ b/middleware/csrf/storage_manager.go @@ -11,7 +11,7 @@ import ( // msgp -file="storage_manager.go" -o="storage_manager_msgp.go" -tests=true -unexported // -//go:generate msgp +//go:generate msgp -o=storage_manager_msgp.go -tests=true -unexported type item struct{} //msgp:ignore manager diff --git a/middleware/idempotency/response.go b/middleware/idempotency/response.go index 0e47ac86..aafbca60 100644 --- a/middleware/idempotency/response.go +++ b/middleware/idempotency/response.go @@ -3,7 +3,7 @@ package idempotency // response is a struct that represents the response of a request. // generation tool `go install github.com/tinylib/msgp@latest` // -//go:generate msgp -o=response_msgp.go -io=false -tests=true -unexported +//go:generate msgp -o=response_msgp.go -tests=true -unexported type response struct { Headers map[string][]string `msg:"hs"` diff --git a/middleware/limiter/manager.go b/middleware/limiter/manager.go index e91c4013..ecaee7f1 100644 --- a/middleware/limiter/manager.go +++ b/middleware/limiter/manager.go @@ -10,7 +10,7 @@ import ( // msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported // -//go:generate msgp +//go:generate msgp -o=manager_msgp.go -tests=false -unexported type item struct { currHits int prevHits int diff --git a/middleware/session/data.go b/middleware/session/data.go index 8d728723..08cb833f 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -6,7 +6,7 @@ import ( // msgp -file="data.go" -o="data_msgp.go" -tests=true -unexported // -//go:generate msgp +//go:generate msgp -o=data_msgp.go -tests=true -unexported type data struct { Data map[string]any sync.RWMutex `msg:"-"` diff --git a/redirect.go b/redirect.go index c59a00a3..ebbcb499 100644 --- a/redirect.go +++ b/redirect.go @@ -6,7 +6,6 @@ package fiber import ( "errors" - "strings" "sync" "github.com/gofiber/fiber/v3/binder" @@ -19,7 +18,7 @@ var redirectPool = sync.Pool{ New: func() any { return &Redirect{ status: StatusFound, - oldInput: make(map[string]string, 0), + messages: make(redirectionMsgs, 0), } }, } @@ -32,13 +31,37 @@ const ( CookieDataAssigner = ":" ) +// redirectionMsgs is a struct that used to store flash messages and old input data in cookie using MSGP. +// msgp -file="redirect.go" -o="redirect_msgp.go" -unexported +// +//msgp:ignore Redirect RedirectConfig OldInputData FlashMessage +type redirectionMsg struct { + key string + value string + level uint8 + isOldInput bool +} + +type redirectionMsgs []redirectionMsg + +// OldInputData is a struct that holds the old input data. +type OldInputData struct { + Key string + Value string +} + +// FlashMessage is a struct that holds the flash message data. +type FlashMessage struct { + Key string + Value string + Level uint8 +} + // Redirect is a struct that holds the redirect data. type Redirect struct { - c *DefaultCtx // Embed ctx - oldInput map[string]string // Old input data - - messages []string // Flash messages - status int // Status code of redirection. Default: StatusFound + c *DefaultCtx // Embed ctx + messages redirectionMsgs // Flash messages and old input data + status int // Status code of redirection. Default: StatusFound } // RedirectConfig A config to use with Redirect().Route() @@ -71,10 +94,6 @@ func ReleaseRedirect(r *Redirect) { func (r *Redirect) release() { r.status = 302 r.messages = r.messages[:0] - // reset map - for k := range r.oldInput { - delete(r.oldInput, k) - } r.c = nil } @@ -90,8 +109,28 @@ func (r *Redirect) Status(code int) *Redirect { // They will be sent as a cookie. // You can get them by using: Redirect().Messages(), Redirect().Message() // Note: You must use escape char before using ',' and ':' chars to avoid wrong parsing. -func (r *Redirect) With(key, value string) *Redirect { - r.messages = append(r.messages, key+CookieDataAssigner+value) +func (r *Redirect) With(key, value string, level ...uint8) *Redirect { + // Get level + var msgLevel uint8 + if len(level) > 0 { + msgLevel = level[0] + } + + // Override old message if exists + for i, msg := range r.messages { + if msg.key == key && !msg.isOldInput { + r.messages[i].value = value + r.messages[i].level = msgLevel + + return r + } + } + + r.messages = append(r.messages, redirectionMsg{ + key: key, + value: value, + level: msgLevel, + }) return r } @@ -105,28 +144,39 @@ func (r *Redirect) WithInput() *Redirect { ctype := utils.ToLower(utils.UnsafeString(r.c.Context().Request.Header.ContentType())) ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype)) + oldInput := make(map[string]string) switch ctype { case MIMEApplicationForm: - _ = r.c.Bind().Form(r.oldInput) //nolint:errcheck // not needed + _ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed case MIMEMultipartForm: - _ = r.c.Bind().MultipartForm(r.oldInput) //nolint:errcheck // not needed + _ = r.c.Bind().MultipartForm(oldInput) //nolint:errcheck // not needed default: - _ = r.c.Bind().Query(r.oldInput) //nolint:errcheck // not needed + _ = r.c.Bind().Query(oldInput) //nolint:errcheck // not needed + } + + // Add old input data + for k, v := range oldInput { + r.messages = append(r.messages, redirectionMsg{ + key: k, + value: v, + isOldInput: true, + }) } return r } // Messages Get flash messages. -func (r *Redirect) Messages() map[string]string { - msgs := r.c.redirectionMessages - flashMessages := make(map[string]string, len(msgs)) +func (r *Redirect) Messages() []FlashMessage { + flashMessages := make([]FlashMessage, 0) - for _, msg := range msgs { - k, v := parseMessage(msg) - - if !strings.HasPrefix(k, OldInputDataPrefix) { - flashMessages[k] = v + for _, msg := range r.c.flashMessages { + if !msg.isOldInput { + flashMessages = append(flashMessages, FlashMessage{ + Key: msg.key, + Value: msg.value, + Level: msg.level, + }) } } @@ -134,47 +184,52 @@ func (r *Redirect) Messages() map[string]string { } // Message Get flash message by key. -func (r *Redirect) Message(key string) string { - msgs := r.c.redirectionMessages +func (r *Redirect) Message(key string) FlashMessage { + msgs := r.c.flashMessages for _, msg := range msgs { - k, v := parseMessage(msg) - - if !strings.HasPrefix(k, OldInputDataPrefix) && k == key { - return v + if msg.key == key && !msg.isOldInput { + return FlashMessage{ + Key: msg.key, + Value: msg.value, + Level: msg.level, + } } } - return "" + + return FlashMessage{} } // OldInputs Get old input data. -func (r *Redirect) OldInputs() map[string]string { - msgs := r.c.redirectionMessages - oldInputs := make(map[string]string, len(msgs)) +func (r *Redirect) OldInputs() []OldInputData { + inputs := make([]OldInputData, 0) - for _, msg := range msgs { - k, v := parseMessage(msg) - - if strings.HasPrefix(k, OldInputDataPrefix) { - // remove "old_input_data_" part from key - oldInputs[k[len(OldInputDataPrefix):]] = v + for _, msg := range r.c.flashMessages { + if msg.isOldInput { + inputs = append(inputs, OldInputData{ + Key: msg.key, + Value: msg.value, + }) } } - return oldInputs + + return inputs } // OldInput Get old input data by key. -func (r *Redirect) OldInput(key string) string { - msgs := r.c.redirectionMessages +func (r *Redirect) OldInput(key string) OldInputData { + msgs := r.c.flashMessages for _, msg := range msgs { - k, v := parseMessage(msg) - - if strings.HasPrefix(k, OldInputDataPrefix) && k[len(OldInputDataPrefix):] == key { - return v + if msg.key == key && msg.isOldInput { + return OldInputData{ + Key: msg.key, + Value: msg.value, + } } } - return "" + + return OldInputData{} } // To redirect to the URL derived from the specified path, with specified status. @@ -240,66 +295,32 @@ func (r *Redirect) Back(fallback ...string) error { return r.To(location) } -// parseAndClearFlashMessages is a method to get flash messages before removing them +// parseAndClearFlashMessages is a method to get flash messages before they are getting removed func (r *Redirect) parseAndClearFlashMessages() { // parse flash messages cookieValue := r.c.Cookies(FlashCookieName) - var commaPos int - for { - commaPos = findNextNonEscapedCharsetPosition(cookieValue, []byte(CookieDataSeparator)) - if commaPos == -1 { - r.c.redirectionMessages = append(r.c.redirectionMessages, utils.Trim(cookieValue, ' ')) - break - } - r.c.redirectionMessages = append(r.c.redirectionMessages, utils.Trim(cookieValue[:commaPos], ' ')) - cookieValue = cookieValue[commaPos+1:] + _, err := r.c.flashMessages.UnmarshalMsg(r.c.app.getBytes(cookieValue)) + if err != nil { + return } - - r.c.ClearCookie(FlashCookieName) } // processFlashMessages is a helper function to process flash messages and old input data // and set them as cookies func (r *Redirect) processFlashMessages() { - // Flash messages - if len(r.messages) > 0 || len(r.oldInput) > 0 { - messageText := bytebufferpool.Get() - defer bytebufferpool.Put(messageText) - - // flash messages - for i, message := range r.messages { - messageText.WriteString(message) - // when there are more messages or oldInput -> add a comma - if len(r.messages)-1 != i || (len(r.messages)-1 == i && len(r.oldInput) > 0) { - messageText.WriteString(CookieDataSeparator) - } - } - r.messages = r.messages[:0] - - // old input data - i := 1 - for k, v := range r.oldInput { - messageText.WriteString(OldInputDataPrefix + k + CookieDataAssigner + v) - if len(r.oldInput) != i { - messageText.WriteString(CookieDataSeparator) - } - i++ - } - - r.c.Cookie(&Cookie{ - Name: FlashCookieName, - Value: r.c.app.getString(messageText.Bytes()), - SessionOnly: true, - }) - } -} - -// parseMessage is a helper function to parse flash messages and old input data -func parseMessage(raw string) (string, string) { //nolint: revive // not necessary - if i := findNextNonEscapedCharsetPosition(raw, []byte(CookieDataAssigner)); i != -1 { - return RemoveEscapeChar(raw[:i]), RemoveEscapeChar(raw[i+1:]) + if len(r.messages) == 0 { + return } - return RemoveEscapeChar(raw), "" + val, err := r.messages.MarshalMsg(nil) + if err != nil { + return + } + + r.c.Cookie(&Cookie{ + Name: FlashCookieName, + Value: r.c.app.getString(val), + SessionOnly: true, + }) } diff --git a/redirect_msgp.go b/redirect_msgp.go new file mode 100644 index 00000000..da11533d --- /dev/null +++ b/redirect_msgp.go @@ -0,0 +1,272 @@ +package fiber + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *redirectionMsg) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.key, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "key") + return + } + case "value": + z.value, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "value") + return + } + case "level": + z.level, err = dc.ReadUint8() + if err != nil { + err = msgp.WrapError(err, "level") + return + } + case "isOldInput": + z.isOldInput, err = dc.ReadBool() + if err != nil { + err = msgp.WrapError(err, "isOldInput") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *redirectionMsg) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 4 + // write "key" + err = en.Append(0x84, 0xa3, 0x6b, 0x65, 0x79) + if err != nil { + return + } + err = en.WriteString(z.key) + if err != nil { + err = msgp.WrapError(err, "key") + return + } + // write "value" + err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + if err != nil { + return + } + err = en.WriteString(z.value) + if err != nil { + err = msgp.WrapError(err, "value") + return + } + // write "level" + err = en.Append(0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c) + if err != nil { + return + } + err = en.WriteUint8(z.level) + if err != nil { + err = msgp.WrapError(err, "level") + return + } + // write "isOldInput" + err = en.Append(0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74) + if err != nil { + return + } + err = en.WriteBool(z.isOldInput) + if err != nil { + err = msgp.WrapError(err, "isOldInput") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *redirectionMsg) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 4 + // string "key" + o = append(o, 0x84, 0xa3, 0x6b, 0x65, 0x79) + o = msgp.AppendString(o, z.key) + // string "value" + o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + o = msgp.AppendString(o, z.value) + // string "level" + o = append(o, 0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c) + o = msgp.AppendUint8(o, z.level) + // string "isOldInput" + o = append(o, 0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74) + o = msgp.AppendBool(o, z.isOldInput) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *redirectionMsg) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.key, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "key") + return + } + case "value": + z.value, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "value") + return + } + case "level": + z.level, bts, err = msgp.ReadUint8Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "level") + return + } + case "isOldInput": + z.isOldInput, bts, err = msgp.ReadBoolBytes(bts) + if err != nil { + err = msgp.WrapError(err, "isOldInput") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *redirectionMsg) Msgsize() (s int) { + s = 1 + 4 + msgp.StringPrefixSize + len(z.key) + 6 + msgp.StringPrefixSize + len(z.value) + 6 + msgp.Uint8Size + 11 + msgp.BoolSize + return +} + +// DecodeMsg implements msgp.Decodable +func (z *redirectionMsgs) DecodeMsg(dc *msgp.Reader) (err error) { + var zb0002 uint32 + zb0002, err = dc.ReadArrayHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + if cap((*z)) >= int(zb0002) { + (*z) = (*z)[:zb0002] + } else { + (*z) = make(redirectionMsgs, zb0002) + } + for zb0001 := range *z { + err = (*z)[zb0001].DecodeMsg(dc) + if err != nil { + err = msgp.WrapError(err, zb0001) + return + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z redirectionMsgs) EncodeMsg(en *msgp.Writer) (err error) { + err = en.WriteArrayHeader(uint32(len(z))) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0003 := range z { + err = z[zb0003].EncodeMsg(en) + if err != nil { + err = msgp.WrapError(err, zb0003) + return + } + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z redirectionMsgs) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + o = msgp.AppendArrayHeader(o, uint32(len(z))) + for zb0003 := range z { + o, err = z[zb0003].MarshalMsg(o) + if err != nil { + err = msgp.WrapError(err, zb0003) + return + } + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *redirectionMsgs) UnmarshalMsg(bts []byte) (o []byte, err error) { + var zb0002 uint32 + zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + if cap((*z)) >= int(zb0002) { + (*z) = (*z)[:zb0002] + } else { + (*z) = make(redirectionMsgs, zb0002) + } + for zb0001 := range *z { + bts, err = (*z)[zb0001].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, zb0001) + return + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z redirectionMsgs) Msgsize() (s int) { + s = msgp.ArrayHeaderSize + for zb0003 := range z { + s += z[zb0003].Msgsize() + } + return +} diff --git a/redirect_msgp_test.go b/redirect_msgp_test.go new file mode 100644 index 00000000..c03d6ffc --- /dev/null +++ b/redirect_msgp_test.go @@ -0,0 +1,236 @@ +package fiber + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalredirectionMsg(t *testing.T) { + v := redirectionMsg{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgredirectionMsg(b *testing.B) { + v := redirectionMsg{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgredirectionMsg(b *testing.B) { + v := redirectionMsg{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalredirectionMsg(b *testing.B) { + v := redirectionMsg{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecoderedirectionMsg(t *testing.T) { + v := redirectionMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecoderedirectionMsg Msgsize() is inaccurate") + } + + vn := redirectionMsg{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncoderedirectionMsg(b *testing.B) { + v := redirectionMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecoderedirectionMsg(b *testing.B) { + v := redirectionMsg{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalredirectionMsgs(t *testing.T) { + v := redirectionMsgs{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgredirectionMsgs(b *testing.B) { + v := redirectionMsgs{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgredirectionMsgs(b *testing.B) { + v := redirectionMsgs{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalredirectionMsgs(b *testing.B) { + v := redirectionMsgs{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecoderedirectionMsgs(t *testing.T) { + v := redirectionMsgs{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecoderedirectionMsgs Msgsize() is inaccurate") + } + + vn := redirectionMsgs{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncoderedirectionMsgs(b *testing.B) { + v := redirectionMsgs{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecoderedirectionMsgs(b *testing.B) { + v := redirectionMsgs{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/redirect_test.go b/redirect_test.go index 95a084b4..7544aec0 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -5,16 +5,13 @@ package fiber import ( - "context" - "net" + "bytes" + "mime/multipart" "net/url" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttputil" ) // go test -run Test_Redirect_To @@ -40,16 +37,20 @@ func Test_Redirect_To_WithFlashMessages(t *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) - err := c.Redirect().With("success", "1").With("message", "test").To("http://example.com") + err := c.Redirect().With("success", "2").With("success", "1").With("message", "test", 2).To("http://example.com") require.NoError(t, err) require.Equal(t, 302, c.Response().StatusCode()) require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) - equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" - require.True(t, equal) + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing - c.Redirect().parseAndClearFlashMessages() - require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 2) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 2, isOldInput: false}) } // go test -run Test_Redirect_Route_WithParams @@ -88,6 +89,7 @@ func Test_Redirect_Route_WithParams_WithQueries(t *testing.T) { }) require.NoError(t, err) require.Equal(t, 302, c.Response().StatusCode()) + // analysis of query parameters with url parsing, since a map pass is always randomly ordered location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) require.NoError(t, err, "url.Parse(location)") @@ -183,11 +185,15 @@ func Test_Redirect_Back_WithFlashMessages(t *testing.T) { require.Equal(t, 302, c.Response().StatusCode()) require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation))) - equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" - require.True(t, equal) + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing - c.Redirect().parseAndClearFlashMessages() - require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 2) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) } // go test -run Test_Redirect_Back_WithReferer @@ -222,43 +228,143 @@ func Test_Redirect_Route_WithFlashMessages(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed err := c.Redirect().With("success", "1").With("message", "test").Route("user") + + require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.NoError(t, err) require.Equal(t, 302, c.Response().StatusCode()) require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) - equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" - require.True(t, equal) + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing - c.Redirect().parseAndClearFlashMessages() - require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 2) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) } // go test -run Test_Redirect_Route_WithOldInput func Test_Redirect_Route_WithOldInput(t *testing.T) { t.Parallel() - app := New() - app.Get("/user", func(c Ctx) error { - return c.SendString("user") - }).Name("user") + t.Run("Query", func(t *testing.T) { + t.Parallel() - c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") - c.Request().URI().SetQueryString("id=1&name=tom") - err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user") - require.NoError(t, err) - require.Equal(t, 302, c.Response().StatusCode()) - require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - require.Contains(t, c.GetRespHeader(HeaderSetCookie), "fiber_flash=") - require.Contains(t, c.GetRespHeader(HeaderSetCookie), "success:1") - require.Contains(t, c.GetRespHeader(HeaderSetCookie), "message:test") + c.Request().URI().SetQueryString("id=1&name=tom") + err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user") - require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_id:1") - require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_name:tom") + require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true}) - c.Redirect().parseAndClearFlashMessages() - require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.NoError(t, err) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing + + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 4) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true}) + require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true}) + }) + + t.Run("Form", func(t *testing.T) { + t.Parallel() + + app := New() + app.Post("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed + + c.Request().Header.Set(HeaderContentType, MIMEApplicationForm) + c.Request().SetBodyString("id=1&name=tom") + err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user") + + require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true}) + + require.NoError(t, err) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing + + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 4) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true}) + require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true}) + }) + + t.Run("MultipartForm", func(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + require.NoError(t, writer.WriteField("id", "1")) + require.NoError(t, writer.WriteField("name", "tom")) + require.NoError(t, writer.Close()) + + c.Request().SetBody(body.Bytes()) + c.Request().Header.Set(HeaderContentType, writer.FormDataContentType()) + + err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user") + + require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true}) + require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true}) + + require.NoError(t, err) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) + + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing + + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(t, err) + + require.Len(t, msgs, 4) + require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) + require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true}) + require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true}) + }) } // go test -run Test_Redirect_parseAndClearFlashMessages @@ -272,105 +378,83 @@ func Test_Redirect_parseAndClearFlashMessages(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + msgs := redirectionMsgs{ + { + key: "success", + value: "1", + }, + { + key: "message", + value: "test", + }, + { + key: "name", + value: "tom", + isOldInput: true, + }, + { + key: "id", + value: "1", + isOldInput: true, + }, + } + + val, err := msgs.MarshalMsg(nil) + require.NoError(t, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) c.Redirect().parseAndClearFlashMessages() - require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(t, FlashMessage{ + Key: "success", + Value: "1", + Level: 0, + }, c.Redirect().Message("success")) - require.Equal(t, "1", c.Redirect().Message("success")) - require.Equal(t, "test", c.Redirect().Message("message")) - require.Equal(t, map[string]string{"success": "1", "message": "test"}, c.Redirect().Messages()) + require.Equal(t, FlashMessage{ + Key: "message", + Value: "test", + Level: 0, + }, c.Redirect().Message("message")) - require.Equal(t, "1", c.Redirect().OldInput("id")) - require.Equal(t, "tom", c.Redirect().OldInput("name")) - require.Equal(t, map[string]string{"id": "1", "name": "tom"}, c.Redirect().OldInputs()) -} + require.Equal(t, FlashMessage{}, c.Redirect().Message("not_message")) -// go test -run Test_Redirect_Request -func Test_Redirect_Request(t *testing.T) { - t.Parallel() - app := New() - - app.Get("/", func(c Ctx) error { - return c.Redirect().With("key", "value").With("key2", "value2").With("co\\:m\\,ma", "Fi\\:ber\\, v3").Route("name") - }) - - app.Get("/with-inputs", func(c Ctx) error { - return c.Redirect().WithInput().With("key", "value").With("key2", "value2").Route("name") - }) - - app.Get("/just-inputs", func(c Ctx) error { - return c.Redirect().WithInput().Route("name") - }) - - app.Get("/redirected", func(c Ctx) error { - return c.JSON(Map{ - "messages": c.Redirect().Messages(), - "inputs": c.Redirect().OldInputs(), - }) - }).Name("name") - - // Start test server - ln := fasthttputil.NewInmemoryListener() - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) - defer cancel() - - err := app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - GracefulContext: ctx, - }) - - assert.NoError(t, err) - }() - - // Test cases - testCases := []struct { - ExpectedErr error - URL string - CookieValue string - ExpectedBody string - ExpectedStatusCode int - }{ + require.Equal(t, []FlashMessage{ { - URL: "/", - CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", - ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, - ExpectedStatusCode: StatusOK, - ExpectedErr: nil, + Key: "success", + Value: "1", + Level: 0, }, { - URL: "/with-inputs?name=john&surname=doe", - CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", - ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, - ExpectedStatusCode: StatusOK, - ExpectedErr: nil, + Key: "message", + Value: "test", + Level: 0, + }, + }, c.Redirect().Messages()) + + require.Equal(t, OldInputData{ + Key: "id", + Value: "1", + }, c.Redirect().OldInput("id")) + + require.Equal(t, OldInputData{ + Key: "name", + Value: "tom", + }, c.Redirect().OldInput("name")) + + require.Equal(t, OldInputData{}, c.Redirect().OldInput("not_name")) + + require.Equal(t, []OldInputData{ + { + Key: "name", + Value: "tom", }, { - URL: "/just-inputs?name=john&surname=doe", - CookieValue: "old_input_data_name:john,old_input_data_surname:doe", - ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, - ExpectedStatusCode: StatusOK, - ExpectedErr: nil, + Key: "id", + Value: "1", }, - } - - for _, tc := range testCases { - client := &fasthttp.HostClient{ - Dial: func(_ string) (net.Conn, error) { - return ln.Dial() - }, - } - req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse() - req.SetRequestURI("http://example.com" + tc.URL) - req.Header.SetCookie(FlashCookieName, tc.CookieValue) - err := client.DoRedirects(req, resp, 1) - - require.NoError(t, err) - require.Equal(t, tc.ExpectedBody, string(resp.Body())) - require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) - } + }, c.Redirect().OldInputs()) } // go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4 @@ -454,11 +538,35 @@ func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) { require.Equal(b, 302, c.Response().StatusCode()) require.Equal(b, "/user", string(c.Response().Header.Peek(HeaderLocation))) - equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" - require.True(b, equal) + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing - c.Redirect().parseAndClearFlashMessages() - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + var msgs redirectionMsgs + _, err = msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(b, err) + + require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) +} + +var testredirectionMsgs = redirectionMsgs{ + { + key: "success", + value: "1", + }, + { + key: "message", + value: "test", + }, + { + key: "name", + value: "tom", + isOldInput: true, + }, + { + key: "id", + value: "1", + isOldInput: true, + }, } // go test -v -run=^$ -bench=Benchmark_Redirect_parseAndClearFlashMessages -benchmem -count=4 @@ -470,7 +578,10 @@ func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + val, err := testredirectionMsgs.MarshalMsg(nil) + require.NoError(b, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) b.ReportAllocs() b.ResetTimer() @@ -479,15 +590,25 @@ func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) { c.Redirect().parseAndClearFlashMessages() } - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) + require.Equal(b, FlashMessage{ + Key: "success", + Value: "1", + }, c.Redirect().Message("success")) - require.Equal(b, "1", c.Redirect().Message("success")) - require.Equal(b, "test", c.Redirect().Message("message")) - require.Equal(b, map[string]string{"success": "1", "message": "test"}, c.Redirect().Messages()) + require.Equal(b, FlashMessage{ + Key: "message", + Value: "test", + }, c.Redirect().Message("message")) - require.Equal(b, "1", c.Redirect().OldInput("id")) - require.Equal(b, "tom", c.Redirect().OldInput("name")) - require.Equal(b, map[string]string{"id": "1", "name": "tom"}, c.Redirect().OldInputs()) + require.Equal(b, OldInputData{ + Key: "id", + Value: "1", + }, c.Redirect().OldInput("id")) + + require.Equal(b, OldInputData{ + Key: "name", + Value: "tom", + }, c.Redirect().OldInput("name")) } // go test -v -run=^$ -bench=Benchmark_Redirect_processFlashMessages -benchmem -count=4 @@ -508,7 +629,15 @@ func Benchmark_Redirect_processFlashMessages(b *testing.B) { c.Redirect().processFlashMessages() } - require.Equal(b, "fiber_flash=success:1,message:test; path=/; SameSite=Lax", c.GetRespHeader(HeaderSetCookie)) + c.Context().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing + + var msgs redirectionMsgs + _, err := msgs.UnmarshalMsg([]byte(c.Cookies(FlashCookieName))) + require.NoError(b, err) + + require.Len(b, msgs, 2) + require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) + require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) } // go test -v -run=^$ -bench=Benchmark_Redirect_Messages -benchmem -count=4 @@ -520,10 +649,13 @@ func Benchmark_Redirect_Messages(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + val, err := testredirectionMsgs.MarshalMsg(nil) + require.NoError(b, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) c.Redirect().parseAndClearFlashMessages() - var msgs map[string]string + var msgs []FlashMessage b.ReportAllocs() b.ResetTimer() @@ -532,8 +664,17 @@ func Benchmark_Redirect_Messages(b *testing.B) { msgs = c.Redirect().Messages() } - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) - require.Equal(b, map[string]string{"success": "1", "message": "test"}, msgs) + require.Contains(b, msgs, FlashMessage{ + Key: "success", + Value: "1", + Level: 0, + }) + + require.Contains(b, msgs, FlashMessage{ + Key: "message", + Value: "test", + Level: 0, + }) } // go test -v -run=^$ -bench=Benchmark_Redirect_OldInputs -benchmem -count=4 @@ -545,10 +686,13 @@ func Benchmark_Redirect_OldInputs(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + val, err := testredirectionMsgs.MarshalMsg(nil) + require.NoError(b, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) c.Redirect().parseAndClearFlashMessages() - var oldInputs map[string]string + var oldInputs []OldInputData b.ReportAllocs() b.ResetTimer() @@ -557,8 +701,15 @@ func Benchmark_Redirect_OldInputs(b *testing.B) { oldInputs = c.Redirect().OldInputs() } - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) - require.Equal(b, map[string]string{"id": "1", "name": "tom"}, oldInputs) + require.Contains(b, oldInputs, OldInputData{ + Key: "name", + Value: "tom", + }) + + require.Contains(b, oldInputs, OldInputData{ + Key: "id", + Value: "1", + }) } // go test -v -run=^$ -bench=Benchmark_Redirect_Message -benchmem -count=4 @@ -570,10 +721,13 @@ func Benchmark_Redirect_Message(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + val, err := testredirectionMsgs.MarshalMsg(nil) + require.NoError(b, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) c.Redirect().parseAndClearFlashMessages() - var msg string + var msg FlashMessage b.ReportAllocs() b.ResetTimer() @@ -582,8 +736,11 @@ func Benchmark_Redirect_Message(b *testing.B) { msg = c.Redirect().Message("message") } - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) - require.Equal(b, "test", msg) + require.Equal(b, FlashMessage{ + Key: "message", + Value: "test", + Level: 0, + }, msg) } // go test -v -run=^$ -bench=Benchmark_Redirect_OldInput -benchmem -count=4 @@ -595,10 +752,13 @@ func Benchmark_Redirect_OldInput(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed - c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") + val, err := testredirectionMsgs.MarshalMsg(nil) + require.NoError(b, err) + + c.Request().Header.Set(HeaderCookie, "fiber_flash="+string(val)) c.Redirect().parseAndClearFlashMessages() - var input string + var input OldInputData b.ReportAllocs() b.ResetTimer() @@ -607,6 +767,8 @@ func Benchmark_Redirect_OldInput(b *testing.B) { input = c.Redirect().OldInput("name") } - require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) - require.Equal(b, "tom", input) + require.Equal(b, OldInputData{ + Key: "name", + Value: "tom", + }, input) }