diff --git a/binder/mapping.go b/binder/mapping.go index 07af94a1..36821be0 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -1,13 +1,15 @@ package binder import ( + "errors" "reflect" "strings" "sync" - "github.com/gofiber/fiber/v3/internal/schema" "github.com/gofiber/utils/v2" "github.com/valyala/bytebufferpool" + + "github.com/gofiber/fiber/v3/internal/schema" ) // ParserConfig form decoder config for SetParserDecoder @@ -132,15 +134,24 @@ func parseParamSquareBrackets(k string) (string, error) { defer bytebufferpool.Put(bb) kbytes := []byte(k) + openBracketsCount := 0 for i, b := range kbytes { - if b == '[' && kbytes[i+1] != ']' { - if err := bb.WriteByte('.'); err != nil { - return "", err //nolint:wrapcheck // unnecessary to wrap it + if b == '[' { + openBracketsCount++ + if i+1 < len(kbytes) && kbytes[i+1] != ']' { + if err := bb.WriteByte('.'); err != nil { + return "", err //nolint:wrapcheck // unnecessary to wrap it + } } + continue } - if b == '[' || b == ']' { + if b == ']' { + openBracketsCount-- + if openBracketsCount < 0 { + return "", errors.New("unmatched brackets") + } continue } @@ -149,6 +160,10 @@ func parseParamSquareBrackets(k string) (string, error) { } } + if openBracketsCount > 0 { + return "", errors.New("unmatched brackets") + } + return bb.String(), nil } diff --git a/binder/mapping_test.go b/binder/mapping_test.go index aec91ff2..e6fc8146 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -1,6 +1,7 @@ package binder import ( + "errors" "reflect" "testing" @@ -29,3 +30,70 @@ func Test_EqualFieldType(t *testing.T) { require.True(t, equalFieldType(&user, reflect.Int, "AGE")) require.True(t, equalFieldType(&user, reflect.Int, "age")) } + +func Test_ParseParamSquareBrackets(t *testing.T) { + tests := []struct { + err error + input string + expected string + }{ + { + err: nil, + input: "foo[bar]", + expected: "foo.bar", + }, + { + err: nil, + input: "foo[bar][baz]", + expected: "foo.bar.baz", + }, + { + err: errors.New("unmatched brackets"), + input: "foo[bar", + expected: "", + }, + { + err: errors.New("unmatched brackets"), + input: "foo[bar][baz", + expected: "", + }, + { + err: errors.New("unmatched brackets"), + input: "foo]bar[", + expected: "", + }, + { + err: nil, + input: "foo[bar[baz]]", + expected: "foo.bar.baz", + }, + { + err: nil, + input: "", + expected: "", + }, + { + err: nil, + input: "[]", + expected: "", + }, + { + err: nil, + input: "foo[]", + expected: "foo", + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, err := parseParamSquareBrackets(tt.input) + if tt.err != nil { + require.Error(t, err) + require.EqualError(t, err, tt.err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +}