diff --git a/ctx.go b/ctx.go index 901b5174..86fae50d 100644 --- a/ctx.go +++ b/ctx.go @@ -260,31 +260,92 @@ func (c *Ctx) BaseURL() string { return c.baseURI } -// Body contains the raw body submitted in a POST request. +// BodyRaw contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. +func (c *Ctx) BodyRaw() []byte { + return c.fasthttp.Request.Body() +} + +func (c *Ctx) tryDecodeBodyInOrder( + originalBody *[]byte, + encodings []string, +) ([]byte, uint8, error) { + var ( + err error + body []byte + decodesRealized uint8 + ) + + for index, encoding := range encodings { + decodesRealized++ + switch encoding { + case StrGzip: + body, err = c.fasthttp.Request.BodyGunzip() + case StrBr, StrBrotli: + body, err = c.fasthttp.Request.BodyUnbrotli() + case StrDeflate: + body, err = c.fasthttp.Request.BodyInflate() + default: + decodesRealized-- + if len(encodings) == 1 { + body = c.fasthttp.Request.Body() + } + return body, decodesRealized, nil + } + + if err != nil { + return nil, decodesRealized, err + } + + // Only execute body raw update if it has a next iteration to try to decode + if index < len(encodings)-1 && decodesRealized > 0 { + if index == 0 { + tempBody := c.fasthttp.Request.Body() + *originalBody = make([]byte, len(tempBody)) + copy(*originalBody, tempBody) + } + c.fasthttp.Request.SetBodyRaw(body) + } + } + + return body, decodesRealized, nil +} + +// Body contains the raw body submitted in a POST request. +// This method will decompress the body if the 'Content-Encoding' header is provided. +// It returns the original (or decompressed) body data which is valid only within the handler. +// Don't store direct references to the returned data. +// If you need to keep the body's data later, make a copy or use the Immutable option. func (c *Ctx) Body() []byte { - var err error - var encoding string - var body []byte + var ( + err error + body, originalBody []byte + headerEncoding string + encodingOrder = []string{"", "", ""} + ) + // faster than peek c.Request().Header.VisitAll(func(key, value []byte) { if c.app.getString(key) == HeaderContentEncoding { - encoding = c.app.getString(value) + headerEncoding = c.app.getString(value) } }) - switch encoding { - case StrGzip: - body, err = c.fasthttp.Request.BodyGunzip() - case StrBr, StrBrotli: - body, err = c.fasthttp.Request.BodyUnbrotli() - case StrDeflate: - body, err = c.fasthttp.Request.BodyInflate() - default: - body = c.fasthttp.Request.Body() + // Split and get the encodings list, in order to attend the + // rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5 + encodingOrder = getSplicedStrList(headerEncoding, encodingOrder) + if len(encodingOrder) == 0 { + return c.fasthttp.Request.Body() } + var decodesRealized uint8 + body, decodesRealized, err = c.tryDecodeBodyInOrder(&originalBody, encodingOrder) + + // Ensure that the body will be the original + if originalBody != nil && decodesRealized > 0 { + c.fasthttp.Request.SetBodyRaw(originalBody) + } if err != nil { return []byte(err.Error()) } diff --git a/ctx_test.go b/ctx_test.go index 38e83ca9..e092fcd3 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "compress/gzip" + "compress/zlib" "context" "crypto/tls" "encoding/xml" @@ -323,47 +324,211 @@ func Test_Ctx_Body(t *testing.T) { utils.AssertEqual(t, []byte("john=doe"), c.Body()) } -// go test -run Test_Ctx_Body_With_Compression -func Test_Ctx_Body_With_Compression(t *testing.T) { - t.Parallel() +func Benchmark_Ctx_Body(b *testing.B) { + const input = "john=doe" + app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) - c.Request().Header.Set("Content-Encoding", "gzip") - var b bytes.Buffer - gz := gzip.NewWriter(&b) - _, err := gz.Write([]byte("john=doe")) - utils.AssertEqual(t, nil, err) - err = gz.Flush() - utils.AssertEqual(t, nil, err) - err = gz.Close() - utils.AssertEqual(t, nil, err) - c.Request().SetBody(b.Bytes()) - utils.AssertEqual(t, []byte("john=doe"), c.Body()) -} - -// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4 -func Benchmark_Ctx_Body_With_Compression(b *testing.B) { - app := New() - c := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(c) - c.Request().Header.Set("Content-Encoding", "gzip") - var buf bytes.Buffer - gz := gzip.NewWriter(&buf) - _, err := gz.Write([]byte("john=doe")) - utils.AssertEqual(b, nil, err) - err = gz.Flush() - utils.AssertEqual(b, nil, err) - err = gz.Close() - utils.AssertEqual(b, nil, err) - - c.Request().SetBody(buf.Bytes()) + c.Request().SetBody([]byte(input)) for i := 0; i < b.N; i++ { _ = c.Body() } - utils.AssertEqual(b, []byte("john=doe"), c.Body()) + utils.AssertEqual(b, []byte(input), c.Body()) +} + +// go test -run Test_Ctx_Body_With_Compression +func Test_Ctx_Body_With_Compression(t *testing.T) { + t.Parallel() + tests := []struct { + name string + contentEncoding string + body []byte + expectedBody []byte + }{ + { + name: "gzip", + contentEncoding: "gzip", + body: []byte("john=doe"), + expectedBody: []byte("john=doe"), + }, + { + name: "unsupported_encoding", + contentEncoding: "undefined", + body: []byte("keeps_ORIGINAL"), + expectedBody: []byte("keeps_ORIGINAL"), + }, + { + name: "gzip then unsupported", + contentEncoding: "gzip, undefined", + body: []byte("Go, be gzipped"), + expectedBody: []byte("Go, be gzipped"), + }, + { + name: "invalid_deflate", + contentEncoding: "gzip,deflate", + body: []byte("I'm not correctly compressed"), + expectedBody: []byte(zlib.ErrHeader.Error()), + }, + } + + for _, testObject := range tests { + tCase := testObject // Duplicate object to ensure it will be unique across all runs + t.Run(tCase.name, func(t *testing.T) { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.Set("Content-Encoding", tCase.contentEncoding) + + if strings.Contains(tCase.contentEncoding, "gzip") { + var b bytes.Buffer + gz := gzip.NewWriter(&b) + _, err := gz.Write(tCase.body) + if err != nil { + t.Fatal(err) + } + if err = gz.Flush(); err != nil { + t.Fatal(err) + } + if err = gz.Close(); err != nil { + t.Fatal(err) + } + tCase.body = b.Bytes() + } + + c.Request().SetBody(tCase.body) + body := c.Body() + utils.AssertEqual(t, tCase.expectedBody, body) + + // Check if body raw is the same as previous before decompression + utils.AssertEqual( + t, tCase.body, c.Request().Body(), + "Body raw must be the same as set before", + ) + }) + } +} + +// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4 +func Benchmark_Ctx_Body_With_Compression(b *testing.B) { + encodingErr := errors.New("failed to encoding data") + + var ( + compressGzip = func(data []byte) ([]byte, error) { + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + if _, err := writer.Write(data); err != nil { + return nil, encodingErr + } + if err := writer.Flush(); err != nil { + return nil, encodingErr + } + if err := writer.Close(); err != nil { + return nil, encodingErr + } + return buf.Bytes(), nil + } + compressDeflate = func(data []byte) ([]byte, error) { + var buf bytes.Buffer + writer := zlib.NewWriter(&buf) + if _, err := writer.Write(data); err != nil { + return nil, encodingErr + } + if err := writer.Flush(); err != nil { + return nil, encodingErr + } + if err := writer.Close(); err != nil { + return nil, encodingErr + } + return buf.Bytes(), nil + } + ) + compressionTests := []struct { + contentEncoding string + compressWriter func([]byte) ([]byte, error) + }{ + { + contentEncoding: "gzip", + compressWriter: compressGzip, + }, + { + contentEncoding: "gzip,invalid", + compressWriter: compressGzip, + }, + { + contentEncoding: "deflate", + compressWriter: compressDeflate, + }, + { + contentEncoding: "gzip,deflate", + compressWriter: func(data []byte) ([]byte, error) { + var ( + buf bytes.Buffer + writer interface { + io.WriteCloser + Flush() error + } + err error + ) + + // deflate + { + writer = zlib.NewWriter(&buf) + if _, err = writer.Write(data); err != nil { + return nil, encodingErr + } + if err = writer.Flush(); err != nil { + return nil, encodingErr + } + if err = writer.Close(); err != nil { + return nil, encodingErr + } + } + + data = make([]byte, buf.Len()) + copy(data, buf.Bytes()) + buf.Reset() + + // gzip + { + writer = gzip.NewWriter(&buf) + if _, err = writer.Write(data); err != nil { + return nil, encodingErr + } + if err = writer.Flush(); err != nil { + return nil, encodingErr + } + if err = writer.Close(); err != nil { + return nil, encodingErr + } + } + + return buf.Bytes(), nil + }, + }, + } + + for _, ct := range compressionTests { + b.Run(ct.contentEncoding, func(b *testing.B) { + app := New() + const input = "john=doe" + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.Request().Header.Set("Content-Encoding", ct.contentEncoding) + compressedBody, err := ct.compressWriter([]byte(input)) + utils.AssertEqual(b, nil, err) + + c.Request().SetBody(compressedBody) + for i := 0; i < b.N; i++ { + _ = c.Body() + } + + utils.AssertEqual(b, []byte(input), c.Body()) + }) + } } // go test -run Test_Ctx_BodyParser diff --git a/docs/api/ctx.md b/docs/api/ctx.md index aa48c609..1da3db97 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -57,13 +57,13 @@ Fiber provides similar functions for the other accept headers. // Accept-Language: en;q=0.8, nl, ru app.Get("/", func(c *fiber.Ctx) error { - c.AcceptsCharsets("utf-16", "iso-8859-1") + c.AcceptsCharsets("utf-16", "iso-8859-1") // "iso-8859-1" - c.AcceptsEncodings("compress", "br") + c.AcceptsEncodings("compress", "br") // "compress" - c.AcceptsLanguages("pt", "nl", "ru") + c.AcceptsLanguages("pt", "nl", "ru") // "nl" // ... }) @@ -171,6 +171,7 @@ app.Get("/", func(c *fiber.Ctx) error { ``` ## Bind + Add vars to default view var map binding to template engine. Variables are read by the Render method and may be overwritten. @@ -190,12 +191,12 @@ app.Get("/", func(c *fiber.Ctx) error { }) ``` -## Body +## BodyRaw Returns the raw request **body**. ```go title="Signature" -func (c *Ctx) Body() []byte +func (c *Ctx) BodyRaw() []byte ``` ```go title="Example" @@ -203,6 +204,26 @@ func (c *Ctx) Body() []byte app.Post("/", func(c *fiber.Ctx) error { // Get raw body from POST request: + return c.Send(c.BodyRaw()) // []byte("user=john") +}) +``` + +> _Returned value is only valid within the handler. Do not store any references. +> Make copies or use the_ [_**`Immutable`**_](ctx.md) _setting instead._ [_Read more..._](../#zero-allocation) + +## Body + +As per the header `Content-Encoding`, this method will try to perform a file decompression from the **body** bytes. In case no `Content-Encoding` header is sent, it will perform as [BodyRaw](#bodyraw). + +```go title="Signature" +func (c *Ctx) Body() []byte +``` + +```go title="Example" +// echo 'user=john' | gzip | curl -v -i --data-binary @- -H "Content-Encoding: gzip" http://localhost:8080 + +app.Post("/", func(c *fiber.Ctx) error { + // Decompress body from POST request based on the Content-Encoding and return the raw content: return c.Send(c.Body()) // []byte("user=john") }) ``` @@ -216,13 +237,13 @@ Binds the request body to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a JSON body with a field called Pass, you would use a struct field of `json:"pass"`. -| content-type | struct tag | -|---|---| -| `application/x-www-form-urlencoded` | form | -| `multipart/form-data` | form | -| `application/json` | json | -| `application/xml` | xml | -| `text/xml` | xml | +| content-type | struct tag | +| ----------------------------------- | ---------- | +| `application/x-www-form-urlencoded` | form | +| `multipart/form-data` | form | +| `application/json` | json | +| `application/xml` | xml | +| `text/xml` | xml | ```go title="Signature" func (c *Ctx) BodyParser(out interface{}) error @@ -693,6 +714,7 @@ app.Get("/", func(c *fiber.Ctx) error { ## IsFromLocal Returns true if request came from localhost + ```go title="Signature" func (c *Ctx) IsFromLocal() bool { ``` @@ -837,7 +859,7 @@ app.Post("/", func(c *fiber.Ctx) error { c.Location("http://example.com") c.Location("/foo/bar") - + return nil }) ``` @@ -1024,6 +1046,7 @@ app.Get("/user/:id", func(c *fiber.Ctx) error { This method is equivalent of using `atoi` with ctx.Params ## ParamsParser + This method is similar to BodyParser, but for path parameters. It is important to use the struct tag "params". For example, if you want to parse a path parameter with a field called Pass, you would use a struct field of params:"pass" ```go title="Signature" @@ -1034,7 +1057,7 @@ func (c *Ctx) ParamsParser(out interface{}) error // GET http://example.com/user/111 app.Get("/user/:id", func(c *fiber.Ctx) error { param := struct {ID uint `params:"id"`}{} - + c.ParamsParser(¶m) // "{"id": 111}" // ... @@ -1176,7 +1199,6 @@ app.Get("/", func(c *fiber.Ctx) error { This property is an object containing a property for each query boolean parameter in the route, you could pass an optional default value that will be returned if the query key does not exist. - :::caution Please note if that parameter is not in the request, false will be returned. If the parameter is not a boolean, it is still tried to be converted and usually returned as false. @@ -1232,12 +1254,10 @@ app.Get("/", func(c *fiber.Ctx) error { }) ``` - ## QueryInt This property is an object containing a property for each query integer parameter in the route, you could pass an optional default value that will be returned if the query key does not exist. - :::caution Please note if that parameter is not in the request, zero will be returned. If the parameter is not a number, it is still tried to be converted and usually returned as 1. @@ -1522,7 +1542,7 @@ func (c *Ctx) Route() *Route app.Get("/hello/:name", func(c *fiber.Ctx) error { r := c.Route() fmt.Println(r.Method, r.Path, r.Params, r.Handlers) - // GET /hello/:name handler [name] + // GET /hello/:name handler [name] // ... }) @@ -1768,7 +1788,7 @@ var timeConverter = func(value string) reflect.Value { customTime := fiber.ParserType{ Customtype: CustomTime{}, Converter: timeConverter, -} +} // Add setting to the Decoder fiber.SetParserDecoder(fiber.ParserConfig{ @@ -1804,7 +1824,6 @@ app.Get("/query", func(c *fiber.Ctx) error { ``` - ## SetUserContext Sets the user specified implementation for context interface. @@ -2020,7 +2039,7 @@ XML also sets the content header to **application/xml**. ::: ```go title="Signature" -func (c *Ctx) XML(data interface{}) error +func (c *Ctx) XML(data interface{}) error ``` ```go title="Example" diff --git a/helpers.go b/helpers.go index cfe31e0d..cc36f13e 100644 --- a/helpers.go +++ b/helpers.go @@ -269,6 +269,41 @@ func acceptsOfferType(spec, offerType string) bool { return false } +// getSplicedStrList function takes a string and a string slice as an argument, divides the string into different +// elements divided by ',' and stores these elements in the string slice. +// It returns the populated string slice as an output. +// +// If the given slice hasn't enough space, it will allocate more and return. +func getSplicedStrList(headerValue string, dst []string) []string { + if headerValue == "" { + return nil + } + + var ( + index int + character rune + lastElementEndsAt uint8 + insertIndex int + ) + for index, character = range headerValue + "$" { + if character == ',' || index == len(headerValue) { + if insertIndex >= len(dst) { + oldSlice := dst + dst = make([]string, len(dst)+(len(dst)>>1)+2) + copy(dst, oldSlice) + } + dst[insertIndex] = utils.TrimLeft(headerValue[lastElementEndsAt:index], ' ') + lastElementEndsAt = uint8(index + 1) + insertIndex++ + } + } + + if len(dst) > insertIndex { + dst = dst[:insertIndex] + } + return dst +} + // getOffer return valid offer for header negotiation func getOffer(header string, isAccepted func(spec, offer string) bool, offers ...string) string { if len(offers) == 0 { diff --git a/helpers_test.go b/helpers_test.go index 5ecab490..788b7a9a 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -107,6 +107,53 @@ func Benchmark_Utils_GetOffer(b *testing.B) { } } +func Test_Utils_GetSplicedStrList(t *testing.T) { + testCases := []struct { + description string + headerValue string + expectedList []string + }{ + { + description: "normal case", + headerValue: "gzip, deflate,br", + expectedList: []string{"gzip", "deflate", "br"}, + }, + { + description: "no matter the value", + headerValue: " gzip,deflate, br, zip", + expectedList: []string{"gzip", "deflate", "br", "zip"}, + }, + { + description: "headerValue is empty", + headerValue: "", + expectedList: nil, + }, + { + description: "has a comma without element", + headerValue: "gzip,", + expectedList: []string{"gzip", ""}, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + dst := make([]string, 10) + result := getSplicedStrList(tc.headerValue, dst) + utils.AssertEqual(t, tc.expectedList, result) + }) + } +} + +func Benchmark_Utils_GetSplicedStrList(b *testing.B) { + destination := make([]string, 5) + result := destination + const input = "deflate, gzip,br,brotli" + for n := 0; n < b.N; n++ { + result = getSplicedStrList(input, destination) + } + utils.AssertEqual(b, []string{"deflate", "gzip", "br", "brotli"}, result) +} + func Test_Utils_SortAcceptedTypes(t *testing.T) { t.Parallel() acceptedTypes := []acceptedType{ diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 20d683f4..db11cb4a 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -287,6 +287,7 @@ func Test_Session_Save_Expiration(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { + const sessionDuration = 5 * time.Second t.Parallel() // session store store := New() @@ -302,7 +303,7 @@ func Test_Session_Save_Expiration(t *testing.T) { sess.Set("name", "john") // expire this session in 5 seconds - sess.SetExpiry(time.Second * 5) + sess.SetExpiry(sessionDuration) // save session err = sess.Save() @@ -314,7 +315,7 @@ func Test_Session_Save_Expiration(t *testing.T) { utils.AssertEqual(t, "john", sess.Get("name")) // just to make sure the session has been expired - time.Sleep(time.Second * 5) + time.Sleep(sessionDuration + (10 * time.Millisecond)) // here you should get a new session sess, err = store.Get(ctx)