From 7eb9d255489192343c647f1d910abbb140705474 Mon Sep 17 00:00:00 2001 From: RW Date: Tue, 31 Dec 2024 16:56:18 +0100 Subject: [PATCH] Support Square Bracket Notation in Multipart Form data (#3268) * Feature Request: Support Square Bracket Notation in Multipart Form Data #3224 * Feature Request: Support Square Bracket Notation in Multipart Form Data #3224 --- ctx.go | 113 ++++++++++++++-------------------------------------- ctx_test.go | 42 +++++++++++++++++++ go.mod | 2 - go.sum | 6 --- helpers.go | 73 +++++++++++++++++++++++++++++++++ 5 files changed, 144 insertions(+), 92 deletions(-) diff --git a/ctx.go b/ctx.go index c7caada7..c0a4413a 100644 --- a/ctx.go +++ b/ctx.go @@ -406,28 +406,30 @@ func (c *Ctx) BodyParser(out interface{}) error { k := c.app.getString(key) v := c.app.getString(val) - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, bodyTag) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatParserData(out, data, bodyTag, k, v, c.app.config.EnableSplittingOnParsers, true) }) + if err != nil { + return err + } + return c.parseToStruct(bodyTag, out, data) } if strings.HasPrefix(ctype, MIMEMultipartForm) { - data, err := c.fasthttp.MultipartForm() + multipartForm, err := c.fasthttp.MultipartForm() if err != nil { return err } - return c.parseToStruct(bodyTag, out, data.Value) + + data := make(map[string][]string) + for key, values := range multipartForm.Value { + err = formatParserData(out, data, bodyTag, key, values, c.app.config.EnableSplittingOnParsers, true) + if err != nil { + return err + } + } + + return c.parseToStruct(bodyTag, out, data) } if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) { if err := xml.Unmarshal(c.Body(), out); err != nil { @@ -531,18 +533,7 @@ func (c *Ctx) CookieParser(out interface{}) error { k := c.app.getString(key) v := c.app.getString(val) - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, cookieTag) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatParserData(out, data, cookieTag, k, v, c.app.config.EnableSplittingOnParsers, true) }) if err != nil { return err @@ -1283,18 +1274,7 @@ func (c *Ctx) QueryParser(out interface{}) error { k := c.app.getString(key) v := c.app.getString(val) - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, queryTag) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatParserData(out, data, queryTag, k, v, c.app.config.EnableSplittingOnParsers, true) }) if err != nil { @@ -1304,61 +1284,26 @@ func (c *Ctx) QueryParser(out interface{}) error { return c.parseToStruct(queryTag, out, data) } -func parseParamSquareBrackets(k string) (string, error) { - bb := bytebufferpool.Get() - defer bytebufferpool.Put(bb) - - kbytes := []byte(k) - openBracketsCount := 0 - - for i, b := range kbytes { - if b == '[' { - openBracketsCount++ - if i+1 < len(kbytes) && kbytes[i+1] != ']' { - if err := bb.WriteByte('.'); err != nil { - return "", fmt.Errorf("failed to write: %w", err) - } - } - continue - } - - if b == ']' { - openBracketsCount-- - if openBracketsCount < 0 { - return "", errors.New("unmatched brackets") - } - continue - } - - if err := bb.WriteByte(b); err != nil { - return "", fmt.Errorf("failed to write: %w", err) - } - } - - if openBracketsCount > 0 { - return "", errors.New("unmatched brackets") - } - - return bb.String(), nil -} - // ReqHeaderParser binds the request header strings to a struct. func (c *Ctx) ReqHeaderParser(out interface{}) error { data := make(map[string][]string) + var err error + c.fasthttp.Request.Header.VisitAll(func(key, val []byte) { + if err != nil { + return + } + k := c.app.getString(key) v := c.app.getString(val) - if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, reqHeaderTag) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatParserData(out, data, reqHeaderTag, k, v, c.app.config.EnableSplittingOnParsers, false) }) + if err != nil { + return err + } + return c.parseToStruct(reqHeaderTag, out, data) } diff --git a/ctx_test.go b/ctx_test.go index 76ee3fa5..6a3998f7 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -610,6 +610,48 @@ func Test_Ctx_BodyParser(t *testing.T) { utils.AssertEqual(t, 2, len(cq.Data)) utils.AssertEqual(t, "john", cq.Data[0].Name) utils.AssertEqual(t, "doe", cq.Data[1].Name) + + t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + utils.AssertEqual(t, nil, writer.WriteField("data.0.name", "john")) + utils.AssertEqual(t, nil, writer.WriteField("data.1.name", "doe")) + utils.AssertEqual(t, nil, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + utils.AssertEqual(t, nil, c.BodyParser(cq)) + utils.AssertEqual(t, len(cq.Data), 2) + utils.AssertEqual(t, "john", cq.Data[0].Name) + utils.AssertEqual(t, "doe", cq.Data[1].Name) + }) + + t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + utils.AssertEqual(t, nil, writer.WriteField("data[0][name]", "john")) + utils.AssertEqual(t, nil, writer.WriteField("data[1][name]", "doe")) + utils.AssertEqual(t, nil, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + utils.AssertEqual(t, nil, c.BodyParser(cq)) + utils.AssertEqual(t, len(cq.Data), 2) + utils.AssertEqual(t, "john", cq.Data[0].Name) + utils.AssertEqual(t, "doe", cq.Data[1].Name) + }) } func Test_Ctx_ParamParser(t *testing.T) { diff --git a/go.mod b/go.mod index 38efaa55..4a3b38df 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,4 @@ require ( github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/mod v0.18.0 // indirect - golang.org/x/tools v0.22.0 // indirect ) diff --git a/go.sum b/go.sum index 06b81bad..67a47b01 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1Gsh github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/tinylib/msgp v1.1.3 h1:3giwAkmtaEDLSV0MdO1lDLuPgklgPzmk8H9+So2BVfA= -github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po= github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -25,11 +23,7 @@ github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1S github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= -golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= -golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= -golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= diff --git a/helpers.go b/helpers.go index dd8de15f..2896c2a4 100644 --- a/helpers.go +++ b/helpers.go @@ -7,6 +7,7 @@ package fiber import ( "bytes" "crypto/tls" + "errors" "fmt" "hash/crc32" "io" @@ -1151,3 +1152,75 @@ func IndexRune(str string, needle int32) bool { } return false } + +func parseParamSquareBrackets(k string) (string, error) { + bb := bytebufferpool.Get() + defer bytebufferpool.Put(bb) + + kbytes := []byte(k) + openBracketsCount := 0 + + for i, b := range kbytes { + if b == '[' { + openBracketsCount++ + if i+1 < len(kbytes) && kbytes[i+1] != ']' { + if err := bb.WriteByte('.'); err != nil { + return "", fmt.Errorf("failed to write: %w", err) + } + } + continue + } + + if b == ']' { + openBracketsCount-- + if openBracketsCount < 0 { + return "", errors.New("unmatched brackets") + } + continue + } + + if err := bb.WriteByte(b); err != nil { + return "", fmt.Errorf("failed to write: %w", err) + } + } + + if openBracketsCount > 0 { + return "", errors.New("unmatched brackets") + } + + return bb.String(), nil +} + +func formatParserData(out interface{}, data map[string][]string, aliasTag, key string, value interface{}, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay + var err error + if supportBracketNotation && strings.Contains(key, "[") { + key, err = parseParamSquareBrackets(key) + if err != nil { + return err + } + } + + switch v := value.(type) { + case string: + assignBindData(out, data, aliasTag, key, v, enableSplitting) + case []string: + for _, val := range v { + assignBindData(out, data, aliasTag, key, val, enableSplitting) + } + default: + return fmt.Errorf("unsupported value type: %T", value) + } + + return err +} + +func assignBindData(out interface{}, data map[string][]string, aliasTag, key, value string, enableSplitting bool) { //nolint:revive // it's okay + if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key, aliasTag) { + values := strings.Split(value, ",") + for i := 0; i < len(values); i++ { + data[key] = append(data[key], values[i]) + } + } else { + data[key] = append(data[key], value) + } +}