diff --git a/ctx.go b/ctx.go index 52787a55..7f41048d 100644 --- a/ctx.go +++ b/ctx.go @@ -117,22 +117,22 @@ type ResFmt struct { // Accepts checks if the specified extensions or content types are acceptable. func (c *DefaultCtx) Accepts(offers ...string) string { - return getOffer(c.Get(HeaderAccept), acceptsOfferType, offers...) + return getOffer(c.fasthttp.Request.Header.Peek(HeaderAccept), acceptsOfferType, offers...) } // AcceptsCharsets checks if the specified charset is acceptable. func (c *DefaultCtx) AcceptsCharsets(offers ...string) string { - return getOffer(c.Get(HeaderAcceptCharset), acceptsOffer, offers...) + return getOffer(c.fasthttp.Request.Header.Peek(HeaderAcceptCharset), acceptsOffer, offers...) } // AcceptsEncodings checks if the specified encoding is acceptable. func (c *DefaultCtx) AcceptsEncodings(offers ...string) string { - return getOffer(c.Get(HeaderAcceptEncoding), acceptsOffer, offers...) + return getOffer(c.fasthttp.Request.Header.Peek(HeaderAcceptEncoding), acceptsOffer, offers...) } // AcceptsLanguages checks if the specified language is acceptable. func (c *DefaultCtx) AcceptsLanguages(offers ...string) string { - return getOffer(c.Get(HeaderAcceptLanguage), acceptsOffer, offers...) + return getOffer(c.fasthttp.Request.Header.Peek(HeaderAcceptLanguage), acceptsOffer, offers...) } // App returns the *App reference to the instance of the Fiber application diff --git a/helpers.go b/helpers.go index 1e1bc408..a07e3253 100644 --- a/helpers.go +++ b/helpers.go @@ -15,6 +15,7 @@ import ( "path/filepath" "reflect" "strings" + "sync" "time" "unsafe" @@ -33,9 +34,11 @@ type acceptedType struct { quality float64 specificity int order int - params string + params headerParams } +type headerParams map[string][]byte + // getTLSConfig returns a net listener's tls config func getTLSConfig(ln net.Listener) *tls.Config { // Get listener type @@ -225,7 +228,7 @@ func getGroupPath(prefix, path string) string { // acceptsOffer This function determines if an offer matches a given specification. // It checks if the specification ends with a '*' or if the offer has the prefix of the specification. // Returns true if the offer matches the specification, false otherwise. -func acceptsOffer(spec, offer, _ string) bool { +func acceptsOffer(spec, offer string, _ headerParams) bool { if len(spec) >= 1 && spec[len(spec)-1] == '*' { return true } else if strings.HasPrefix(spec, offer) { @@ -240,7 +243,7 @@ func acceptsOffer(spec, offer, _ string) bool { // It checks if the offer MIME type matches the specification MIME type or if the specification is of the form /* and the offer MIME type has the same MIME type. // It checks if the offer contains every parameter present in the specification. // Returns true if the offer type matches the specification, false otherwise. -func acceptsOfferType(spec, offerType, specParams string) bool { +func acceptsOfferType(spec, offerType string, specParams headerParams) bool { var offerMime, offerParams string if i := strings.IndexByte(offerType, ';'); i == -1 { @@ -286,35 +289,18 @@ func acceptsOfferType(spec, offerType, specParams string) bool { // For the sake of simplicity, we forgo this and compare the value as-is. Besides, it would // be highly unusual for a client to escape something other than a double quote or backslash. // See https://www.rfc-editor.org/rfc/rfc9110#name-parameters -func paramsMatch(specParamStr, offerParams string) bool { - if specParamStr == "" { +func paramsMatch(specParamStr headerParams, offerParams string) bool { + if len(specParamStr) == 0 { return true } - // Preprocess the spec params to more easily test - // for out-of-order parameters - specParams := make([][2]string, 0, 2) - forEachParameter(specParamStr, func(s1, s2 string) bool { - if s1 == "q" || s1 == "Q" { - return false - } - for i := range specParams { - if utils.EqualFold(s1, specParams[i][0]) { - specParams[i][1] = s2 - return false - } - } - specParams = append(specParams, [2]string{s1, s2}) - return true - }) - allSpecParamsMatch := true - for i := range specParams { + for specParam, specVal := range specParamStr { foundParam := false - forEachParameter(offerParams, func(offerParam, offerVal string) bool { - if utils.EqualFold(specParams[i][0], offerParam) { + fasthttp.VisitHeaderParams(utils.UnsafeBytes(offerParams), func(key, value []byte) bool { + if utils.EqualFold(specParam, string(key)) { foundParam = true - allSpecParamsMatch = utils.EqualFold(specParams[i][1], offerVal) + allSpecParamsMatch = utils.EqualFold(specVal, value) return false } return true @@ -323,6 +309,7 @@ func paramsMatch(specParamStr, offerParams string) bool { return false } } + return allSpecParamsMatch } @@ -364,12 +351,12 @@ func getSplicedStrList(headerValue string, dst []string) []string { // forEachMediaRange parses an Accept or Content-Type header, calling functor // on each media range. // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields -func forEachMediaRange(header string, functor func(string)) { - hasDQuote := strings.IndexByte(header, '"') != -1 +func forEachMediaRange(header []byte, functor func([]byte)) { + hasDQuote := bytes.IndexByte(header, '"') != -1 for len(header) > 0 { n := 0 - header = strings.TrimLeft(header, " ") + header = bytes.TrimLeft(header, " ") quotes := 0 escaping := false @@ -395,7 +382,7 @@ func forEachMediaRange(header string, functor func(string)) { } } else { // Simple case. Just look for the next comma. - if n = strings.IndexByte(header, ','); n == -1 { + if n = bytes.IndexByte(header, ','); n == -1 { n = len(header) } } @@ -409,133 +396,22 @@ func forEachMediaRange(header string, functor func(string)) { } } -// forEachParamter parses a given parameter list, calling functor -// on each valid parameter. If functor returns false, we stop processing. -// It expects a leading ';'. -// See: https://www.rfc-editor.org/rfc/rfc9110#section-5.6.6 -// According to RFC-9110 2.4, it is up to our discretion whether -// to attempt to recover from errors in HTTP semantics. Therefor, -// we take the simple approach and exit early when a semantic error -// is detected in the header. -// -// parameter = parameter-name "=" parameter-value -// parameter-name = token -// parameter-value = ( token / quoted-string ) -// parameters = *( OWS ";" OWS [ parameter ] ) -func forEachParameter(params string, functor func(string, string) bool) { - for len(params) > 0 { - // eat OWS ";" OWS - params = strings.TrimLeft(params, " ") - if len(params) == 0 || params[0] != ';' { - return - } - params = strings.TrimLeft(params[1:], " ") - - n := 0 - - // make sure the parameter is at least one character long - if len(params) == 0 || !validHeaderFieldByte(params[n]) { - return - } - n++ - for n < len(params) && validHeaderFieldByte(params[n]) { - n++ - } - - // We should hit a '=' (that has more characters after it) - // If not, the parameter is invalid. - // param=foo - // ~~~~~^ - if n >= len(params)-1 || params[n] != '=' { - return - } - param := params[:n] - n++ - - if params[n] == '"' { - // Handle quoted strings and quoted-pairs (i.e., characters escaped with \ ) - // See: https://www.rfc-editor.org/rfc/rfc9110#section-5.6.4 - foundEndQuote := false - escaping := false - n++ - m := n - for ; n < len(params); n++ { - if params[n] == '"' && !escaping { - foundEndQuote = true - break - } - // Recipients that process the value of a quoted-string MUST handle - // a quoted-pair as if it were replaced by the octet following the backslash - escaping = params[n] == '\\' && !escaping - } - if !foundEndQuote { - // Not a valid parameter - return - } - if !functor(param, params[m:n]) { - return - } - n++ - } else if validHeaderFieldByte(params[n]) { - // Parse a normal value, which should just be a token. - m := n - n++ - for n < len(params) && validHeaderFieldByte(params[n]) { - n++ - } - if !functor(param, params[m:n]) { - return - } - } else { - // Value was invalid - return - } - params = params[n:] - } +// Pool for headerParams instances. The headerParams object *must* +// be cleared before being returned to the pool. +var headerParamPool = sync.Pool{ + New: func() any { + return make(headerParams) + }, } -// validHeaderFieldByte returns true if a valid tchar -// -// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / -// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA -// -// See: https://www.rfc-editor.org/rfc/rfc9110#section-5.6.2 -// Function copied from net/textproto: -// https://github.com/golang/go/blob/master/src/net/textproto/reader.go#L663 -func validHeaderFieldByte(c byte) bool { - // mask is a 128-bit bitmap with 1s for allowed bytes, - // so that the byte c can be tested with a shift and an and. - // If c >= 128, then 1<>64)) != 0 -} - -// getOffer return valid offer for header negotiation -func getOffer(header string, isAccepted func(spec, offer, specParams string) bool, offers ...string) string { +// getOffer return valid offer for header negotiation. +// Do not pass header using utils.UnsafeBytes - this can cause a panic due +// to the use of utils.ToLowerBytes. +func getOffer(header []byte, isAccepted func(spec, offer string, specParams headerParams) bool, offers ...string) string { if len(offers) == 0 { return "" } - if header == "" { + if len(header) == 0 { return offers[0] } @@ -544,36 +420,36 @@ func getOffer(header string, isAccepted func(spec, offer, specParams string) boo // Parse header and get accepted types with their quality and specificity // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields - forEachMediaRange(header, func(accept string) { + forEachMediaRange(header, func(accept []byte) { order++ - spec, quality, params := accept, 1.0, "" + spec, quality := accept, 1.0 - if i := strings.IndexByte(accept, ';'); i != -1 { + var params headerParams + + if i := bytes.IndexByte(accept, ';'); i != -1 { spec = accept[:i] // The vast majority of requests will have only the q parameter with // no whitespace. Check this first to see if we can skip // the more involved parsing. - if strings.HasPrefix(accept[i:], ";q=") && strings.IndexByte(accept[i+3:], ';') == -1 { - if q, err := fasthttp.ParseUfloat([]byte(strings.TrimRight(accept[i+3:], " "))); err == nil { + if bytes.HasPrefix(accept[i:], []byte(";q=")) && bytes.IndexByte(accept[i+3:], ';') == -1 { + if q, err := fasthttp.ParseUfloat(bytes.TrimRight(accept[i+3:], " ")); err == nil { quality = q } } else { - hasParams := false - forEachParameter(accept[i:], func(param, val string) bool { - if param == "q" || param == "Q" { - if q, err := fasthttp.ParseUfloat([]byte(val)); err == nil { + params, _ = headerParamPool.Get().(headerParams) //nolint:errcheck // only contains headerParams + fasthttp.VisitHeaderParams(accept[i:], func(key, value []byte) bool { + if string(key) == "q" { + if q, err := fasthttp.ParseUfloat(value); err == nil { quality = q } return false } - hasParams = true + params[utils.UnsafeString(utils.ToLowerBytes(key))] = value return true }) - if hasParams { - params = accept[i:] - } } + // Skip this accept type if quality is 0.0 // See: https://www.rfc-editor.org/rfc/rfc9110#quality.values if quality == 0.0 { @@ -581,23 +457,23 @@ func getOffer(header string, isAccepted func(spec, offer, specParams string) boo } } - spec = strings.TrimRight(spec, " ") + spec = bytes.TrimRight(spec, " ") // Get specificity var specificity int // check for wildcard this could be a mime */* or a wildcard character * - if spec == "*/*" || spec == "*" { + if string(spec) == "*/*" || string(spec) == "*" { specificity = 1 - } else if strings.HasSuffix(spec, "/*") { + } else if bytes.HasSuffix(spec, []byte("/*")) { specificity = 2 - } else if strings.IndexByte(spec, '/') != -1 { + } else if bytes.IndexByte(spec, '/') != -1 { specificity = 3 } else { specificity = 4 } // Add to accepted types - acceptedTypes = append(acceptedTypes, acceptedType{spec, quality, specificity, order, params}) + acceptedTypes = append(acceptedTypes, acceptedType{utils.UnsafeString(spec), quality, specificity, order, params}) }) if len(acceptedTypes) > 1 { @@ -606,18 +482,30 @@ func getOffer(header string, isAccepted func(spec, offer, specParams string) boo } // Find the first offer that matches the accepted types + ret := "" + done := false for _, acceptedType := range acceptedTypes { - for _, offer := range offers { - if len(offer) == 0 { - continue + if !done { + for _, offer := range offers { + if offer == "" { + continue + } + if isAccepted(acceptedType.spec, offer, acceptedType.params) { + ret = offer + done = true + break + } } - if isAccepted(acceptedType.spec, offer, acceptedType.params) { - return offer + } + if acceptedType.params != nil { + for p := range acceptedType.params { + delete(acceptedType.params, p) } + headerParamPool.Put(acceptedType.params) } } - return "" + return ret } // sortAcceptedTypes sorts accepted types by quality and specificity, preserving order of equal elements diff --git a/helpers_fuzz_test.go b/helpers_fuzz_test.go index 2fce6475..ee617012 100644 --- a/helpers_fuzz_test.go +++ b/helpers_fuzz_test.go @@ -18,6 +18,6 @@ func FuzzUtilsGetOffer(f *testing.F) { f.Add(input) } f.Fuzz(func(_ *testing.T, spec string) { - getOffer(spec, acceptsOfferType, `application/json;version=1;v=1;foo=bar`, `text/plain;param="big fox"`) + getOffer([]byte(spec), acceptsOfferType, `application/json;version=1;v=1;foo=bar`, `text/plain;param="big fox"`) }) } diff --git a/helpers_test.go b/helpers_test.go index dac99930..ee0caa80 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -17,49 +17,49 @@ import ( func Test_Utils_GetOffer(t *testing.T) { t.Parallel() - require.Equal(t, "", getOffer("hello", acceptsOffer)) - require.Equal(t, "1", getOffer("", acceptsOffer, "1")) - require.Equal(t, "", getOffer("2", acceptsOffer, "1")) + require.Equal(t, "", getOffer([]byte("hello"), acceptsOffer)) + require.Equal(t, "1", getOffer([]byte(""), acceptsOffer, "1")) + require.Equal(t, "", getOffer([]byte("2"), acceptsOffer, "1")) - require.Equal(t, "", getOffer("", acceptsOfferType)) - require.Equal(t, "", getOffer("text/html", acceptsOfferType)) - require.Equal(t, "", getOffer("text/html", acceptsOfferType, "application/json")) - require.Equal(t, "", getOffer("text/html;q=0", acceptsOfferType, "text/html")) - require.Equal(t, "", getOffer("application/json, */*; q=0", acceptsOfferType, "image/png")) - require.Equal(t, "application/xml", getOffer("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", acceptsOfferType, "application/xml", "application/json")) - require.Equal(t, "text/html", getOffer("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", acceptsOfferType, "text/html")) - require.Equal(t, "application/pdf", getOffer("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000", acceptsOfferType, "application/pdf", "application/json")) - require.Equal(t, "application/pdf", getOffer("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000", acceptsOfferType, "application/pdf", "application/json")) - require.Equal(t, "text/plain;a=1", getOffer("text/plain;a=1", acceptsOfferType, "text/plain;a=1")) - require.Equal(t, "", getOffer("text/plain;a=1;b=2", acceptsOfferType, "text/plain;b=2")) + require.Equal(t, "", getOffer([]byte(""), acceptsOfferType)) + require.Equal(t, "", getOffer([]byte("text/html"), acceptsOfferType)) + require.Equal(t, "", getOffer([]byte("text/html"), acceptsOfferType, "application/json")) + require.Equal(t, "", getOffer([]byte("text/html;q=0"), acceptsOfferType, "text/html")) + require.Equal(t, "", getOffer([]byte("application/json, */*; q=0"), acceptsOfferType, "image/png")) + require.Equal(t, "application/xml", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "application/xml", "application/json")) + require.Equal(t, "text/html", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "text/html")) + require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/pdf", "application/json")) + require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/pdf", "application/json")) + require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain;a=1"), acceptsOfferType, "text/plain;a=1")) + require.Equal(t, "", getOffer([]byte("text/plain;a=1;b=2"), acceptsOfferType, "text/plain;b=2")) // Spaces, quotes, out of order params, and case insensitivity - require.Equal(t, "text/plain", getOffer("text/plain ", acceptsOfferType, "text/plain")) - require.Equal(t, "text/plain", getOffer("text/plain;q=0.4 ", acceptsOfferType, "text/plain")) - require.Equal(t, "text/plain", getOffer("text/plain;q=0.4 ;", acceptsOfferType, "text/plain")) - require.Equal(t, "text/plain", getOffer("text/plain;q=0.4 ; p=foo", acceptsOfferType, "text/plain")) - require.Equal(t, "text/plain;b=2;a=1", getOffer("text/plain ;a=1;b=2", acceptsOfferType, "text/plain;b=2;a=1")) - require.Equal(t, "text/plain;a=1", getOffer("text/plain; a=1 ", acceptsOfferType, "text/plain;a=1")) - require.Equal(t, `text/plain;a="1;b=2\",text/plain"`, getOffer(`text/plain;a="1;b=2\",text/plain";q=0.9`, acceptsOfferType, `text/plain;a=1;b=2`, `text/plain;a="1;b=2\",text/plain"`)) - require.Equal(t, "text/plain;A=CAPS", getOffer(`text/plain;a="caPs"`, acceptsOfferType, "text/plain;A=CAPS")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain "), acceptsOfferType, "text/plain")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 "), acceptsOfferType, "text/plain")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ;"), acceptsOfferType, "text/plain")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ; p=foo"), acceptsOfferType, "text/plain")) + require.Equal(t, "text/plain;b=2;a=1", getOffer([]byte("text/plain ;a=1;b=2"), acceptsOfferType, "text/plain;b=2;a=1")) + require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain; a=1 "), acceptsOfferType, "text/plain;a=1")) + require.Equal(t, `text/plain;a="1;b=2\",text/plain"`, getOffer([]byte(`text/plain;a="1;b=2\",text/plain";q=0.9`), acceptsOfferType, `text/plain;a=1;b=2`, `text/plain;a="1;b=2\",text/plain"`)) + require.Equal(t, "text/plain;A=CAPS", getOffer([]byte(`text/plain;a="caPs"`), acceptsOfferType, "text/plain;A=CAPS")) // Priority - require.Equal(t, "text/plain", getOffer("text/plain", acceptsOfferType, "text/plain", "text/plain;a=1")) - require.Equal(t, "text/plain;a=1", getOffer("text/plain", acceptsOfferType, "text/plain;a=1", "text/plain")) - require.Equal(t, "text/plain;a=1", getOffer("text/plain,text/plain;a=1", acceptsOfferType, "text/plain", "text/plain;a=1")) - require.Equal(t, "text/plain", getOffer("text/plain;q=0.899,text/plain;a=1;q=0.898", acceptsOfferType, "text/plain", "text/plain;a=1")) - require.Equal(t, "text/plain;a=1;b=2", getOffer("text/plain,text/plain;a=1,text/plain;a=1;b=2", acceptsOfferType, "text/plain", "text/plain;a=1", "text/plain;a=1;b=2")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain", "text/plain;a=1")) + require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain;a=1", "", "text/plain")) + require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain,text/plain;a=1"), acceptsOfferType, "text/plain", "text/plain;a=1")) + require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.899,text/plain;a=1;q=0.898"), acceptsOfferType, "text/plain", "text/plain;a=1")) + require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain,text/plain;a=1,text/plain;a=1;b=2"), acceptsOfferType, "text/plain", "text/plain;a=1", "text/plain;a=1;b=2")) // Takes the last value specified - require.Equal(t, "text/plain;a=1;b=2", getOffer("text/plain;a=1;b=1;B=2", acceptsOfferType, "text/plain;a=1;b=1", "text/plain;a=1;b=2")) + require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain;a=1;b=1;B=2"), acceptsOfferType, "text/plain;a=1;b=1", "text/plain;a=1;b=2")) - require.Equal(t, "", getOffer("utf-8, iso-8859-1;q=0.5", acceptsOffer)) - require.Equal(t, "", getOffer("utf-8, iso-8859-1;q=0.5", acceptsOffer, "ascii")) - require.Equal(t, "utf-8", getOffer("utf-8, iso-8859-1;q=0.5", acceptsOffer, "utf-8")) - require.Equal(t, "iso-8859-1", getOffer("utf-8;q=0, iso-8859-1;q=0.5", acceptsOffer, "utf-8", "iso-8859-1")) + require.Equal(t, "", getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer)) + require.Equal(t, "", getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "ascii")) + require.Equal(t, "utf-8", getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "utf-8")) + require.Equal(t, "iso-8859-1", getOffer([]byte("utf-8;q=0, iso-8859-1;q=0.5"), acceptsOffer, "utf-8", "iso-8859-1")) - require.Equal(t, "deflate", getOffer("gzip, deflate", acceptsOffer, "deflate")) - require.Equal(t, "", getOffer("gzip, deflate;q=0", acceptsOffer, "deflate")) + require.Equal(t, "deflate", getOffer([]byte("gzip, deflate"), acceptsOffer, "deflate")) + require.Equal(t, "", getOffer([]byte("gzip, deflate;q=0"), acceptsOffer, "deflate")) } // go test -v -run=^$ -bench=Benchmark_Utils_GetOffer -benchmem -count=4 @@ -90,9 +90,6 @@ func Benchmark_Utils_GetOffer(b *testing.B) { offers: []string{"application/json;version=1;foo=bar"}, }, { - // 1 alloc: - // The implementation uses a slice of length 2 allocated on the stack, - // so a third parameters causes a heap allocation. description: "3 parameters", accept: "application/json; version=1; foo=bar; charset=utf-8", offers: []string{"application/json;version=1;foo=bar;charset=utf-8"}, @@ -142,175 +139,58 @@ func Benchmark_Utils_GetOffer(b *testing.B) { } for _, tc := range testCases { + accept := []byte(tc.accept) b.Run(tc.description, func(b *testing.B) { for n := 0; n < b.N; n++ { - getOffer(tc.accept, acceptsOfferType, tc.offers...) + getOffer(accept, acceptsOfferType, tc.offers...) } }) } } -func Test_Utils_ForEachParameter(t *testing.T) { - testCases := []struct { - description string - paramStr string - expectedParams [][]string - }{ - { - description: "empty input", - paramStr: ``, - }, - { - description: "no parameters", - paramStr: `; `, - }, - { - description: "naked equals", - paramStr: `; = `, - }, - { - description: "no value", - paramStr: `;s=`, - }, - { - description: "no name", - paramStr: `;=bar`, - }, - { - description: "illegal characters in name", - paramStr: `; foo@bar=baz`, - }, - { - description: "value starts with illegal characters", - paramStr: `; foo=@baz; param=val`, - }, - { - description: "unterminated quoted value", - paramStr: `; foo="bar`, - }, - { - description: "illegal character after value terminates parsing", - paramStr: `; foo=bar@baz; param=val`, - expectedParams: [][]string{ - {"foo", "bar"}, - }, - }, - { - description: "parses parameters", - paramStr: `; foo=bar; PARAM=BAZ`, - expectedParams: [][]string{ - {"foo", "bar"}, - {"PARAM", "BAZ"}, - }, - }, - { - description: "stops parsing when functor returns false", - paramStr: `; foo=bar; end=baz; extra=unparsed`, - expectedParams: [][]string{ - {"foo", "bar"}, - {"end", "baz"}, - }, - }, - { - description: "stops parsing when encountering a non-parameter string", - paramStr: `; foo=bar; gzip; param=baz`, - expectedParams: [][]string{ - {"foo", "bar"}, - }, - }, - { - description: "quoted string with escapes and special characters", - // Note: the sequence \\\" is effectively an escaped backslash \\ and - // an escaped double quote \" - paramStr: `;foo="20t\w,b\\\"b;s=k o"`, - expectedParams: [][]string{ - {"foo", `20t\w,b\\\"b;s=k o`}, - }, - }, - { - description: "complex", - paramStr: ` ; foo=1 ; bar="\"value\""; end="20tw,b\\\"b;s=k o" ; action=skip `, - expectedParams: [][]string{ - {"foo", "1"}, - {"bar", `\"value\"`}, - {"end", `20tw,b\\\"b;s=k o`}, - }, - }, - } - for _, tc := range testCases { - n := 0 - forEachParameter(tc.paramStr, func(p, v string) bool { - require.Less(t, n, len(tc.expectedParams), "Received more parameters than expected: "+p+"="+v) - require.Equal(t, tc.expectedParams[n][0], p, tc.description) - require.Equal(t, tc.expectedParams[n][1], v, tc.description) - n++ - - // Stop parsing at the first parameter called "end" - return p != "end" - }) - require.Len(t, tc.expectedParams, n, tc.description+": number of parameters differs") - } - // Check that we exited on the second parameter (bar) -} - -// go test -v -run=^$ -bench=Benchmark_Utils_ForEachParameter -benchmem -count=4 -func Benchmark_Utils_ForEachParameter(b *testing.B) { - for n := 0; n < b.N; n++ { - forEachParameter(` ; josua=1 ; vermant="20tw\",bob;sack o" ; version=1; foo=bar; `, func(_, _ string) bool { - return true - }) - } -} - func Test_Utils_ParamsMatch(t *testing.T) { testCases := []struct { description string - accept string + accept headerParams offer string match bool }{ { description: "empty accept and offer", - accept: "", + accept: nil, offer: "", match: true, }, { description: "accept is empty, offer has params", - accept: "", + accept: make(headerParams), offer: ";foo=bar", match: true, }, { description: "offer is empty, accept has params", - accept: ";foo=bar", + accept: headerParams{"foo": []byte("bar")}, offer: "", match: false, }, { description: "accept has extra parameters", - accept: ";foo=bar;a=1", + accept: headerParams{"foo": []byte("bar"), "a": []byte("1")}, offer: ";foo=bar", match: false, }, { description: "matches regardless of order", - accept: "; a=1; b=2", + accept: headerParams{"b": []byte("2"), "a": []byte("1")}, offer: ";b=2;a=1", match: true, }, { description: "case insensitive", - accept: ";ParaM=FoO", + accept: headerParams{"ParaM": []byte("FoO")}, offer: ";pAram=foO", match: true, }, - { - description: "ignores q", - accept: ";q=0.42", - offer: "", - match: true, - }, } for _, tc := range testCases { @@ -320,8 +200,13 @@ func Test_Utils_ParamsMatch(t *testing.T) { func Benchmark_Utils_ParamsMatch(b *testing.B) { var match bool + + specParams := headerParams{ + "appLe": []byte("orange"), + "param": []byte("foo"), + } for n := 0; n < b.N; n++ { - match = paramsMatch(`; appLe=orange; param="foo"`, `;param=foo; apple=orange`) + match = paramsMatch(specParams, `;param=foo; apple=orange`) } require.True(b, match) } @@ -330,7 +215,7 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { testCases := []struct { description string spec string - specParams string + specParams headerParams offerType string accepts bool }{ @@ -349,14 +234,14 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { { description: "params match", spec: "application/json", - specParams: `; format=foo; version=1`, + specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, offerType: "application/json;version=1;format=foo;q=0.1", accepts: true, }, { description: "spec has extra params", spec: "text/html", - specParams: "; charset=utf-8", + specParams: headerParams{"charset": []byte("utf-8")}, offerType: "text/html", accepts: false, }, @@ -369,14 +254,14 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { { description: "ignores optional whitespace", spec: "application/json", - specParams: `;format=foo; version=1`, + specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, offerType: "application/json; version=1 ; format=foo ", accepts: true, }, { description: "ignores optional whitespace", spec: "application/json", - specParams: `;format="foo bar"; version=1`, + specParams: headerParams{"format": []byte("foo bar"), "version": []byte("1")}, offerType: `application/json;version="1";format="foo bar"`, accepts: true, }, @@ -448,7 +333,7 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) { {spec: "image/*", quality: 1, specificity: 2, order: 8}, {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, - {spec: "application/json", quality: 0.999, specificity: 3, params: ";a=1", order: 11}, + {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, } sortAcceptedTypes(&acceptedTypes) require.Equal(t, []acceptedType{ @@ -460,7 +345,7 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) { {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, {spec: "image/*", quality: 1, specificity: 2, order: 8}, - {spec: "application/json", quality: 0.999, specificity: 3, params: ";a=1", order: 11}, + {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, {spec: "text/*", quality: 0.5, specificity: 2, order: 1}, {spec: "*/*", quality: 0.1, specificity: 1, order: 2},