diff --git a/app.go b/app.go index 06c29e11..efdd7265 100644 --- a/app.go +++ b/app.go @@ -105,6 +105,10 @@ type App struct { server *fasthttp.Server // App config config Config + // Converts string to a byte slice + getBytes func(s string) (b []byte) + // Converts byte slice to a string + getString func(b []byte) string } // Config is a struct holding the server settings. @@ -364,7 +368,9 @@ func New(config ...Config) *App { }, }, // Create config - config: Config{}, + config: Config{}, + getBytes: utils.GetBytes, + getString: utils.GetString, } // Override config if provided if len(config) > 0 { @@ -394,7 +400,7 @@ func New(config ...Config) *App { app.config.CompressedFileSuffix = DefaultCompressedFileSuffix } if app.config.Immutable { - getBytes, getString = getBytesImmutable, getStringImmutable + app.getBytes, app.getString = getBytesImmutable, getStringImmutable } if app.config.ErrorHandler == nil { app.config.ErrorHandler = DefaultErrorHandler diff --git a/app_test.go b/app_test.go index e114d5a2..da638b8e 100644 --- a/app_test.go +++ b/app_test.go @@ -335,7 +335,7 @@ func Test_App_Use_UnescapedPath(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") // check the param result - utils.AssertEqual(t, "اختبار", getString(body)) + utils.AssertEqual(t, "اختبار", app.getString(body)) // with lowercase letters resp, err = app.Test(httptest.NewRequest(MethodGet, "/cr%C3%A9er/%D8%A7%D8%AE%D8%AA%D8%A8%D8%A7%D8%B1", nil)) @@ -370,7 +370,7 @@ func Test_App_Use_CaseSensitive(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") // check the detected path result - utils.AssertEqual(t, "/AbC", getString(body)) + utils.AssertEqual(t, "/AbC", app.getString(body)) } func Test_App_Add_Method_Test(t *testing.T) { diff --git a/client.go b/client.go index 2e6bf759..739e62f2 100644 --- a/client.go +++ b/client.go @@ -596,7 +596,7 @@ func (a *Agent) MultipartForm(args *Args) *Agent { if args != nil { args.VisitAll(func(key, value []byte) { - if err := a.mw.WriteField(getString(key), getString(value)); err != nil { + if err := a.mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)); err != nil { a.errs = append(a.errs, err) } }) @@ -785,7 +785,7 @@ func (a *Agent) Bytes() (code int, body []byte, errs []error) { func printDebugInfo(req *Request, resp *Response, w io.Writer) { msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) - _, _ = w.Write(getBytes(msg)) + _, _ = w.Write(utils.UnsafeBytes(msg)) _, _ = req.WriteTo(w) _, _ = resp.WriteTo(w) } @@ -794,7 +794,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) { func (a *Agent) String() (int, string, []error) { code, body, errs := a.Bytes() - return code, getString(body), errs + return code, utils.UnsafeString(body), errs } // Struct returns the status code, bytes body and errors of url. diff --git a/ctx.go b/ctx.go index 905f27e5..65fb2743 100644 --- a/ctx.go +++ b/ctx.go @@ -92,9 +92,9 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx { // Reset matched flag c.matched = false // Set paths - c.pathOriginal = getString(fctx.URI().PathOriginal()) + c.pathOriginal = app.getString(fctx.URI().PathOriginal()) // Set method - c.method = getString(fctx.Request.Header.Method()) + c.method = app.getString(fctx.Request.Header.Method()) c.methodINT = methodInt(c.method) // Attach *fasthttp.RequestCtx to ctx c.fasthttp = fctx @@ -195,7 +195,7 @@ func (c *Ctx) Append(field string, values ...string) { if len(values) == 0 { return } - h := getString(c.fasthttp.Response.Header.Peek(field)) + h := c.app.getString(c.fasthttp.Response.Header.Peek(field)) originalH := h for _, value := range values { if len(h) == 0 { @@ -216,7 +216,7 @@ func (c *Ctx) Attachment(filename ...string) { fname := filepath.Base(filename[0]) c.Type(filepath.Ext(fname)) - c.setCanonical(HeaderContentDisposition, `attachment; filename="`+quoteString(fname)+`"`) + c.setCanonical(HeaderContentDisposition, `attachment; filename="`+c.app.quoteString(fname)+`"`) return } c.setCanonical(HeaderContentDisposition, "attachment") @@ -339,7 +339,7 @@ func (c *Ctx) Cookie(cookie *Cookie) { // The returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *Ctx) Cookies(key string, defaultValue ...string) string { - return defaultString(getString(c.fasthttp.Request.Header.Cookie(key)), defaultValue) + return defaultString(c.app.getString(c.fasthttp.Request.Header.Cookie(key)), defaultValue) } // Download transfers the file from path as an attachment. @@ -353,7 +353,7 @@ func (c *Ctx) Download(file string, filename ...string) error { } else { fname = filepath.Base(file) } - c.setCanonical(HeaderContentDisposition, `attachment; filename="`+quoteString(fname)+`"`) + c.setCanonical(HeaderContentDisposition, `attachment; filename="`+c.app.quoteString(fname)+`"`) return c.SendFile(file) } @@ -385,7 +385,7 @@ func (c *Ctx) Format(body interface{}) error { case string: b = val case []byte: - b = getString(val) + b = c.app.getString(val) default: b = fmt.Sprintf("%v", val) } @@ -420,7 +420,7 @@ func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *Ctx) FormValue(key string, defaultValue ...string) string { - return defaultString(getString(c.fasthttp.FormValue(key)), defaultValue) + return defaultString(c.app.getString(c.fasthttp.FormValue(key)), defaultValue) } // Fresh returns true when the response is still “fresh” in the client's cache, @@ -449,16 +449,16 @@ func (c *Ctx) Fresh() bool { // if-none-match if noneMatch != "" && noneMatch != "*" { - var etag = getString(c.fasthttp.Response.Header.Peek(HeaderETag)) + var etag = c.app.getString(c.fasthttp.Response.Header.Peek(HeaderETag)) if etag == "" { return false } - if isEtagStale(etag, getBytes(noneMatch)) { + if c.app.isEtagStale(etag, c.app.getBytes(noneMatch)) { return false } if modifiedSince != "" { - var lastModified = getString(c.fasthttp.Response.Header.Peek(HeaderLastModified)) + var lastModified = c.app.getString(c.fasthttp.Response.Header.Peek(HeaderLastModified)) if lastModified != "" { lastModifiedTime, err := http.ParseTime(lastModified) if err != nil { @@ -480,14 +480,14 @@ func (c *Ctx) Fresh() bool { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *Ctx) Get(key string, defaultValue ...string) string { - return defaultString(getString(c.fasthttp.Request.Header.Peek(key)), defaultValue) + return defaultString(c.app.getString(c.fasthttp.Request.Header.Peek(key)), defaultValue) } // Hostname contains the hostname derived from the Host HTTP header. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *Ctx) Hostname() string { - return getString(c.fasthttp.Request.URI().Host()) + return c.app.getString(c.fasthttp.Request.URI().Host()) } // IP returns the remote IP address of the request. @@ -509,10 +509,10 @@ func (c *Ctx) IPs() (ips []string) { for { commaPos = bytes.IndexByte(header, ',') if commaPos != -1 { - ips[i] = utils.Trim(getString(header[:commaPos]), ' ') + ips[i] = utils.Trim(c.app.getString(header[:commaPos]), ' ') header, i = header[commaPos+1:], i+1 } else { - ips[i] = utils.Trim(getString(header), ' ') + ips[i] = utils.Trim(c.app.getString(header), ' ') return } } @@ -565,7 +565,7 @@ func (c *Ctx) JSONP(data interface{}, callback ...string) error { cb = "callback" } - result = cb + "(" + getString(raw) + ");" + result = cb + "(" + c.app.getString(raw) + ");" c.setCanonical(HeaderXContentTypeOptions, "nosniff") c.fasthttp.Response.Header.SetContentType(MIMEApplicationJavaScriptCharsetUTF8) @@ -587,7 +587,7 @@ func (c *Ctx) Links(link ...string) { _, _ = bb.WriteString(`; rel="` + link[i] + `",`) } } - c.setCanonical(HeaderLink, utils.TrimRight(getString(bb.Bytes()), ',')) + c.setCanonical(HeaderLink, utils.TrimRight(c.app.getString(bb.Bytes()), ',')) bytebufferpool.Put(bb) } @@ -645,7 +645,7 @@ func (c *Ctx) Next() (err error) { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *Ctx) OriginalURL() string { - return getString(c.fasthttp.Request.Header.RequestURI()) + return c.app.getString(c.fasthttp.Request.Header.RequestURI()) } // Params is used to get the route parameters. @@ -706,14 +706,14 @@ func (c *Ctx) Protocol() string { return // X-Forwarded- } else if bytes.HasPrefix(key, []byte("X-Forwarded-")) { if bytes.Equal(key, []byte(HeaderXForwardedProto)) { - scheme = getString(val) + scheme = c.app.getString(val) } else if bytes.Equal(key, []byte(HeaderXForwardedProtocol)) { - scheme = getString(val) + scheme = c.app.getString(val) } else if bytes.Equal(key, []byte(HeaderXForwardedSsl)) && bytes.Equal(val, []byte("on")) { scheme = "https" } } else if bytes.Equal(key, []byte(HeaderXUrlScheme)) { - scheme = getString(val) + scheme = c.app.getString(val) } }) return scheme @@ -725,7 +725,7 @@ func (c *Ctx) Protocol() string { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *Ctx) Query(key string, defaultValue ...string) string { - return defaultString(getString(c.fasthttp.QueryArgs().Peek(key)), defaultValue) + return defaultString(c.app.getString(c.fasthttp.QueryArgs().Peek(key)), defaultValue) } // QueryParser binds the query string to a struct. @@ -882,7 +882,7 @@ func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error { return err } // Parse template - if tmpl, err = template.New("").Parse(getString(buf.Bytes())); err != nil { + if tmpl, err = template.New("").Parse(c.app.getString(buf.Bytes())); err != nil { return err } buf.Reset() @@ -1124,7 +1124,7 @@ func (c *Ctx) configDependentPaths() { if c.app.config.UnescapePath { c.pathBuffer = fasthttp.AppendUnquotedArg(c.pathBuffer[:0], c.pathBuffer) } - c.path = getString(c.pathBuffer) + c.path = c.app.getString(c.pathBuffer) // another path is specified which is for routing recognition only // use the path that was changed by the previous configuration flags @@ -1137,7 +1137,7 @@ func (c *Ctx) configDependentPaths() { if !c.app.config.StrictRouting && len(c.detectionPathBuffer) > 1 && c.detectionPathBuffer[len(c.detectionPathBuffer)-1] == '/' { c.detectionPathBuffer = utils.TrimRightBytes(c.detectionPathBuffer, '/') } - c.detectionPath = getString(c.detectionPathBuffer) + c.detectionPath = c.app.getString(c.detectionPathBuffer) // Define the path for dividing routes into areas for fast tree detection, so that fewer routes need to be traversed, // since the first three characters area select a list of routes diff --git a/ctx_test.go b/ctx_test.go index 97936b27..cf0d4132 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -229,7 +229,7 @@ func Benchmark_Ctx_Append(b *testing.B) { c.Append("X-Custom-Header", "World") c.Append("X-Custom-Header", "Hello") } - utils.AssertEqual(b, "Hello, World", getString(c.Response().Header.Peek("X-Custom-Header"))) + utils.AssertEqual(b, "Hello, World", app.getString(c.Response().Header.Peek("X-Custom-Header"))) } // go test -run Test_Ctx_Attachment @@ -481,7 +481,7 @@ func Benchmark_Ctx_Cookie(b *testing.B) { Value: "Doe", }) } - utils.AssertEqual(b, "John=Doe; path=/; SameSite=Lax", getString(c.Response().Header.Peek("Set-Cookie"))) + utils.AssertEqual(b, "John=Doe; path=/; SameSite=Lax", app.getString(c.Response().Header.Peek("Set-Cookie"))) } // go test -run Test_Ctx_Cookies diff --git a/helpers.go b/helpers.go index a0f8d161..30b3101f 100644 --- a/helpers.go +++ b/helpers.go @@ -95,10 +95,10 @@ func readContent(rf io.ReaderFrom, name string) (n int64, err error) { } // quoteString escape special characters in a given string -func quoteString(raw string) string { +func (app *App) quoteString(raw string) string { bb := bytebufferpool.Get() // quoted := string(fasthttp.AppendQuotedArg(bb.B, getBytes(raw))) - quoted := getString(fasthttp.AppendQuotedArg(bb.B, getBytes(raw))) + quoted := app.getString(fasthttp.AppendQuotedArg(bb.B, app.getBytes(raw))) bytebufferpool.Put(bb) return quoted } @@ -272,7 +272,7 @@ func matchEtag(s string, etag string) bool { return false } -func isEtagStale(etag string, noneMatchBytes []byte) bool { +func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool { var start, end int // Adapted from: @@ -285,7 +285,7 @@ func isEtagStale(etag string, noneMatchBytes []byte) bool { end = i + 1 } case 0x2c: - if matchEtag(getString(noneMatchBytes[start:end]), etag) { + if matchEtag(app.getString(noneMatchBytes[start:end]), etag) { return false } start = i + 1 @@ -295,7 +295,7 @@ func isEtagStale(etag string, noneMatchBytes []byte) bool { } } - return !matchEtag(getString(noneMatchBytes[start:end]), etag) + return !matchEtag(app.getString(noneMatchBytes[start:end]), etag) } func parseAddr(raw string) (host, port string) { @@ -359,14 +359,10 @@ func (c *testConn) SetDeadline(_ time.Time) error { return nil } func (c *testConn) SetReadDeadline(_ time.Time) error { return nil } func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil } -// getString converts byte slice to a string without memory allocation. -var getString = utils.UnsafeString var getStringImmutable = func(b []byte) string { return string(b) } -// getBytes converts string to a byte slice without memory allocation. -var getBytes = utils.UnsafeBytes var getBytesImmutable = func(s string) (b []byte) { return []byte(s) } diff --git a/helpers_test.go b/helpers_test.go index 1600ce76..291d49a0 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -182,9 +182,9 @@ func Benchmark_Utils_Unescape(b *testing.B) { for n := 0; n < b.N; n++ { source := "/cr%C3%A9er" - pathBytes := getBytes(source) + pathBytes := utils.UnsafeBytes(source) pathBytes = fasthttp.AppendUnquotedArg(dst[:0], pathBytes) - unescaped = getString(pathBytes) + unescaped = utils.UnsafeString(pathBytes) } utils.AssertEqual(b, "/créer", unescaped) diff --git a/router.go b/router.go index c8765e3b..b842307c 100644 --- a/router.go +++ b/router.go @@ -325,7 +325,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router { PathRewrite: func(fctx *fasthttp.RequestCtx) []byte { path := fctx.Path() if len(path) >= prefixLen { - if isStar && getString(path[0:prefixLen]) == prefix { + if isStar && app.getString(path[0:prefixLen]) == prefix { path = append(path[0:0], '/') } else if len(path) > 0 && path[len(path)-1] != '/' { path = append(path[prefixLen:], '/') diff --git a/router_test.go b/router_test.go index 01509ba9..bebcc30d 100644 --- a/router_test.go +++ b/router_test.go @@ -43,7 +43,7 @@ func Test_Route_Match_SameLength(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, ":param", getString(body)) + utils.AssertEqual(t, ":param", app.getString(body)) // with param resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", nil)) @@ -52,7 +52,7 @@ func Test_Route_Match_SameLength(t *testing.T) { body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "test", getString(body)) + utils.AssertEqual(t, "test", app.getString(body)) } func Test_Route_Match_Star(t *testing.T) { @@ -68,7 +68,7 @@ func Test_Route_Match_Star(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "*", getString(body)) + utils.AssertEqual(t, "*", app.getString(body)) // with param resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", nil)) @@ -77,7 +77,7 @@ func Test_Route_Match_Star(t *testing.T) { body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "test", getString(body)) + utils.AssertEqual(t, "test", app.getString(body)) // without parameter route := Route{ @@ -114,7 +114,7 @@ func Test_Route_Match_Root(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "root", getString(body)) + utils.AssertEqual(t, "root", app.getString(body)) } func Test_Route_Match_Parser(t *testing.T) { @@ -132,7 +132,7 @@ func Test_Route_Match_Parser(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "bar", getString(body)) + utils.AssertEqual(t, "bar", app.getString(body)) // with star resp, err = app.Test(httptest.NewRequest(MethodGet, "/Foobar/test", nil)) @@ -141,7 +141,7 @@ func Test_Route_Match_Parser(t *testing.T) { body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "test", getString(body)) + utils.AssertEqual(t, "test", app.getString(body)) } func Test_Route_Match_Middleware(t *testing.T) { @@ -157,7 +157,7 @@ func Test_Route_Match_Middleware(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "*", getString(body)) + utils.AssertEqual(t, "*", app.getString(body)) // with param resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/bar/fasel", nil)) @@ -166,7 +166,7 @@ func Test_Route_Match_Middleware(t *testing.T) { body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "bar/fasel", getString(body)) + utils.AssertEqual(t, "bar/fasel", app.getString(body)) } func Test_Route_Match_UnescapedPath(t *testing.T) { @@ -182,7 +182,7 @@ func Test_Route_Match_UnescapedPath(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "test", getString(body)) + utils.AssertEqual(t, "test", app.getString(body)) // without special chars resp, err = app.Test(httptest.NewRequest(MethodGet, "/créer", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -208,7 +208,7 @@ func Test_Route_Match_Middleware_HasPrefix(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "middleware", getString(body)) + utils.AssertEqual(t, "middleware", app.getString(body)) } func Test_Route_Match_Middleware_Root(t *testing.T) { @@ -224,7 +224,7 @@ func Test_Route_Match_Middleware_Root(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "middleware", getString(body)) + utils.AssertEqual(t, "middleware", app.getString(body)) } func Test_Router_Register_Missing_Handler(t *testing.T) {