Merge branch 'master' into master

This commit is contained in:
hi019 2020-12-13 22:08:51 -05:00 committed by GitHub
commit 494474aebd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
98 changed files with 4178 additions and 1186 deletions

1
.github/README.md vendored
View File

@ -531,6 +531,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [sujit-baniya/fiber-boilerplate](https://github.com/sujit-baniya/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Contribute

View File

@ -532,6 +532,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Mitwirken

View File

@ -529,6 +529,8 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Contribuir

View File

@ -532,7 +532,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Contribuer
Si vous voulez nous remercier et/ou soutenir le développement actif de `Fiber`:

View File

@ -662,6 +662,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
</div>

View File

@ -532,6 +532,7 @@ Berikut adalah kumpulan _middlewares_ yang dibuat oleh komunitas Fiber, silahkan
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Berkontribusi

17
.github/README_ja.md vendored
View File

@ -106,9 +106,10 @@ func main() {
## ⚙️ インストール
Make sure you have Go installed ([download](https://golang.org/dl/)). Version `1.14` or higher is required.
Goがインストールされていることを確認してください ([ダウンロード](https://golang.org/dl/)). バージョン `1.14` またはそれ以上であることが必要です。
Initialize your project by creating a folder and then running `go mod init github.com/your/repo` ([learn more](https://blog.golang.org/using-go-modules)) inside the folder. Then install Fiber with the [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) command:
フォルダを作成し、フォルダ内で `go mod init github.com/your/repo` ([learn more](https://blog.golang.org/using-go-modules)) を実行してプロジェクトを初期化してください。その後、 Fiber を以下の [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) コマンドでインストールしてください。
```bash
go get -u github.com/gofiber/fiber/v2
@ -126,7 +127,7 @@ go get -u github.com/gofiber/fiber/v2
- [Template engines](https://github.com/gofiber/template)
- [WebSocket support](https://github.com/gofiber/websocket)
- [Rate Limiter](https://docs.gofiber.io/middleware#limiter)
- Available in [15 languages](https://docs.gofiber.io/)
- [15ヶ国語](https://docs.gofiber.io/)で利用可能
- [Fiber](https://docs.gofiber.io/)をもっと知る
## 💡 哲学
@ -142,8 +143,6 @@ Fiber は人気の高い Web フレームワークである Expressjs に**イ
以下に一般的な例をいくつか示します。他のコード例をご覧になりたい場合は、 [Recipes リポジトリ](https://github.com/gofiber/recipes)または[API ドキュメント](https://docs.gofiber.io)にアクセスしてください。
Listed below are some of the common examples. If you want to see more code examples , please visit our [Recipes repository](https://github.com/gofiber/recipes) or visit our hosted [API documentation](https://docs.gofiber.io).
#### 📖 [**Basic Routing**](https://docs.gofiber.io/#basic-routing)
```go
@ -485,7 +484,7 @@ func main() {
## 🧬 Internal Middleware
Here is a list of middleware that are included within the Fiber framework.
以下はFiberフレームワークに含まれるミドルウェアの一覧です。
| Middleware | Description |
| :------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
@ -506,7 +505,7 @@ Here is a list of middleware that are included within the Fiber framework.
## 🧬 External Middleware
List of externally hosted middleware modules and maintained by the [Fiber team](https://github.com/orgs/gofiber/people).
[Fiber team](https://github.com/orgs/gofiber/people) により管理・運用されているミドルウェアの一覧です。
| Middleware | Description |
| :------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
@ -522,6 +521,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
## 🌱 Third Party Middlewares
This is a list of middlewares that are created by the Fiber community, please create a PR if you want to see yours!
これらはFiberのコミュニティーによって作成されたミドルウェアの一覧です。もしあなたのミドルウェアを掲載したい場合はPRを作成してください
- [arsmn/fiber-casbin](https://github.com/arsmn/fiber-casbin)
- [arsmn/fiber-introspect](https://github.com/arsmn/fiber-introspect)
@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 貢献する
@ -544,11 +545,11 @@ This is a list of middlewares that are created by the Fiber community, please cr
1. [GitHub Star](https://github.com/gofiber/fiber/stargazers)をつけてください 。
2. [あなたの Twitter で](https://twitter.com/intent/tweet?text=Fiber%20is%20an%20Express%20inspired%20%23web%20%23framework%20built%20on%20top%20of%20Fasthttp%2C%20the%20fastest%20HTTP%20engine%20for%20%23Go.%20Designed%20to%20ease%20things%20up%20for%20%23fast%20development%20with%20zero%20memory%20allocation%20and%20%23performance%20in%20mind%20%F0%9F%9A%80%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber)プロジェクトについてツイートしてください。
3. [Medium](https://medium.com/) 、 [Dev.to、](https://dev.to/)または個人のブログでレビューまたはチュートリアルを書いてください。
4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny).
4. [cup of coffee](https://buymeacoff.ee/fenny)の寄付でプロジェクトを支援しましょう。
## ☕ Supporters
Fiber is an open source project that runs on donations to pay the bills e.g. our domain name, gitbook, netlify and serverless hosting. If you want to support Fiber, you can ☕ [**buy a coffee here**](https://buymeacoff.ee/fenny).
Fiberはオープンソースプロジェクトで、寄付によってドメイン名やgitbook、 netlify、そしてサーバーレスホスティングなどの費用を賄っています。もしFiberを支援したければ ☕ [**こちらから**](https://buymeacoff.ee/fenny)。
| | User | Donation |
| :--------------------------------------------------------- | :----------------------------------------------- | :------- |

View File

@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 기여

View File

@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Bijdragen

View File

@ -530,6 +530,7 @@ Esta é uma lista de middlewares criados pela comunidade do Fiber, se quiser ter
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Contribuindo

View File

@ -532,6 +532,7 @@ func main() {
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Помощь проекту

View File

@ -596,6 +596,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 مساهمة

View File

@ -529,6 +529,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 Destek

View File

@ -531,6 +531,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 贡献

View File

@ -532,6 +532,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
## 👍 貢獻

44
app.go
View File

@ -31,7 +31,7 @@ import (
)
// Version of current fiber package
const Version = "2.2.0"
const Version = "2.2.5"
// Handler defines a function to serve HTTP requests.
type Handler = func(*Ctx) error
@ -39,27 +39,27 @@ type Handler = func(*Ctx) error
// Map is a shortcut for map[string]interface{}, useful for JSON returns
type Map map[string]interface{}
// Storage interface that is implemented by storage providers for different
// middleware packages like cache, limiter, session and csrf
// Storage interface for communicating with different database/key-value
// providers
type Storage interface {
// Get retrieves the value for the given key.
// If no value is not found it returns ErrNotExit error
// Get gets the value for the given key.
// It returns ErrNotFound if the storage does not contain the key.
Get(key string) ([]byte, error)
// Set stores the given value for the given key along with a
// time-to-live expiration value, 0 means live for ever
// The key must not be "" and the empty values are ignored.
// Empty key or value will be ignored without an error.
Set(key string, val []byte, ttl time.Duration) error
// Delete deletes the stored value for the given key.
// Deleting a non-existing key-value pair does NOT lead to an error.
// The key must not be "".
// Delete deletes the value for the given key.
// It returns no error if the storage does not contain the key,
Delete(key string) error
// Reset the storage
// Reset resets the storage and delete all keys.
Reset() error
// Close the storage
// Close closes the storage and will stop any running garbage
// collectors and open connections.
Close() error
}
@ -149,6 +149,8 @@ type Config struct {
ETag bool `json:"etag"`
// Max body size that the server accepts.
// -1 will decline any body size
//
// Default: 4 * 1024 * 1024
BodyLimit int `json:"body_limit"`
@ -352,7 +354,7 @@ func New(config ...Config) *App {
}
// Override default values
if app.config.BodyLimit <= 0 {
if app.config.BodyLimit == 0 {
app.config.BodyLimit = DefaultBodyLimit
}
if app.config.Concurrency <= 0 {
@ -373,13 +375,15 @@ func New(config ...Config) *App {
if app.config.ErrorHandler == nil {
app.config.ErrorHandler = DefaultErrorHandler
}
// Init app
app.init()
// Return app
return app
}
// Mount attaches another app instance as a subrouter along a routing path.
// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
func (app *App) Mount(prefix string, fiber *App) Router {
@ -659,7 +663,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
type disableLogger struct{}
func (dl *disableLogger) Printf(format string, args ...interface{}) {
func (dl *disableLogger) Printf(_ string, _ ...interface{}) {
// fmt.Println(fmt.Sprintf(format, args...))
}
@ -735,7 +739,7 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
logo += " │ %s │\n"
logo += " │ %s │\n"
logo += " │ │\n"
logo += " │ Handlers %s Threads %s │\n"
logo += " │ Handlers %s Processes %s │\n"
logo += " │ Prefork .%s PID ....%s │\n"
logo += " └───────────────────────────────────────────────────┘"
logo += "%s"
@ -811,11 +815,16 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
isPrefork = "Enabled"
}
procs := strconv.Itoa(runtime.GOMAXPROCS(0))
if !app.config.Prefork {
procs = "1"
}
mainLogo := fmt.Sprintf(logo,
cBlack,
centerValue(" Fiber v"+Version, 49),
center(addr, 49),
value(strconv.Itoa(app.handlerCount), 14), value(strconv.Itoa(runtime.GOMAXPROCS(0)), 14),
value(strconv.Itoa(app.handlerCount), 14), value(procs, 12),
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
cReset,
)
@ -908,6 +917,5 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
out = colorable.NewNonColorable(os.Stdout)
}
fmt.Fprintln(out, output)
_, _ = fmt.Fprintln(out, output)
}

22
ctx.go
View File

@ -247,38 +247,42 @@ var decoderPool = &sync.Pool{New: func() interface{} {
// BodyParser binds the request body to a struct.
// It supports decoding the following content types based on the Content-Type header:
// application/json, application/xml, application/x-www-form-urlencoded, multipart/form-data
// If none of the content types above are matched, it will return a ErrUnprocessableEntity error
func (c *Ctx) BodyParser(out interface{}) error {
// Get decoder from pool
schemaDecoder := decoderPool.Get().(*schema.Decoder)
defer decoderPool.Put(schemaDecoder)
// Get content-type
ctype := getString(c.fasthttp.Request.Header.ContentType())
ctype := utils.ToLower(utils.UnsafeString(c.fasthttp.Request.Header.ContentType()))
// Parse body accordingly
if strings.HasPrefix(ctype, MIMEApplicationJSON) {
schemaDecoder.SetAliasTag("json")
return json.Unmarshal(c.fasthttp.Request.Body(), out)
} else if strings.HasPrefix(ctype, MIMEApplicationForm) {
}
if strings.HasPrefix(ctype, MIMEApplicationForm) {
schemaDecoder.SetAliasTag("form")
data := make(map[string][]string)
c.fasthttp.PostArgs().VisitAll(func(key []byte, val []byte) {
data[getString(key)] = append(data[getString(key)], getString(val))
data[utils.UnsafeString(key)] = append(data[utils.UnsafeString(key)], utils.UnsafeString(val))
})
return schemaDecoder.Decode(out, data)
} else if strings.HasPrefix(ctype, MIMEMultipartForm) {
}
if strings.HasPrefix(ctype, MIMEMultipartForm) {
schemaDecoder.SetAliasTag("form")
data, err := c.fasthttp.MultipartForm()
if err != nil {
return err
}
return schemaDecoder.Decode(out, data.Value)
} else if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
}
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
schemaDecoder.SetAliasTag("xml")
return xml.Unmarshal(c.fasthttp.Request.Body(), out)
}
// No suitable content type found
return fmt.Errorf("bodyparser: cannot parse content-type: %v", ctype)
return ErrUnprocessableEntity
}
// ClearCookie expires a specific cookie by key on the client side.
@ -946,7 +950,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
})
// Keep original path for mutable params
c.pathOriginal = utils.SafeString(c.pathOriginal)
c.pathOriginal = utils.CopyString(c.pathOriginal)
// Disable compression
if len(compress) <= 0 || !compress[0] {
// https://github.com/valyala/fasthttp/blob/master/fs.go#L46
@ -1017,7 +1021,7 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
// Set sets the response's HTTP header field to the specified key, value.
func (c *Ctx) Set(key string, val string) {
c.fasthttp.Response.Header.Set(key, removeNewLines(val))
c.fasthttp.Response.Header.Set(key, val)
}
func (c *Ctx) setCanonical(key string, val string) {
@ -1098,7 +1102,7 @@ func (c *Ctx) WriteString(s string) (int, error) {
// XHR returns a Boolean property, that is true, if the request's X-Requested-With header field is XMLHttpRequest,
// indicating that the request was issued by a client library (such as jQuery).
func (c *Ctx) XHR() bool {
return utils.EqualsFold(utils.UnsafeBytes(c.Get(HeaderXRequestedWith)), []byte("xmlhttprequest"))
return utils.EqualFoldBytes(utils.UnsafeBytes(c.Get(HeaderXRequestedWith)), []byte("xmlhttprequest"))
}
// prettifyPath ...

5
go.mod
View File

@ -3,7 +3,6 @@ module github.com/gofiber/fiber/v2
go 1.14
require (
github.com/klauspost/compress v1.11.0 // indirect
github.com/valyala/fasthttp v1.17.0
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
github.com/valyala/fasthttp v1.18.0
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018
)

12
go.sum
View File

@ -2,25 +2,25 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDa
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=
github.com/valyala/fasthttp v1.17.0/go.mod h1:jjraHZVbKOXftJfsOYoAjaeygpj5hr8ermTRJNroD7A=
github.com/valyala/fasthttp v1.18.0 h1:IV0DdMlatq9QO1Cr6wGJPVW1sV1Q8HvZXAIcjorylyM=
github.com/valyala/fasthttp v1.18.0/go.mod h1:jjraHZVbKOXftJfsOYoAjaeygpj5hr8ermTRJNroD7A=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a h1:0R4NLDRDZX6JcmhJgXi5E4b8Wg84ihbmUKp/GvSPEzc=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1 h1:a/mKvvZr9Jcc8oKfcmgzyp7OwF73JPWsQLvH1z2Kxck=
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018 h1:XKi8B/gRBuTZN1vU9gFsLMm6zVz5FSCDzm8JYACnjy8=
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -15,7 +15,7 @@ type Group struct {
prefix string
}
// Mount attaches another app instance as a subrouter along a routing path.
// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
func (grp *Group) Mount(prefix string, fiber *App) Router {

View File

@ -94,29 +94,6 @@ func quoteString(raw string) string {
return quoted
}
// removeNewLines will replace `\r` and `\n` with an empty space
func removeNewLines(raw string) string {
start := 0
if start = strings.IndexByte(raw, '\r'); start == -1 {
if start = strings.IndexByte(raw, '\n'); start == -1 {
return raw
}
}
bb := bytebufferpool.Get()
buf := bb.Bytes()
buf = append(buf, raw...)
for i := start; i < len(buf); i++ {
if buf[i] != '\r' && buf[i] != '\n' {
continue
}
buf[i] = ' '
}
raw = utils.UnsafeString(buf)
bytebufferpool.Put(bb)
return raw
}
// Scan stack if other methods match the request
func methodExist(ctx *Ctx) (exist bool) {
for i := 0; i < len(intMethod); i++ {
@ -364,9 +341,9 @@ func (c *testConn) Close() error { return nil }
func (c *testConn) LocalAddr() net.Addr { return testAddr("local-addr") }
func (c *testConn) RemoteAddr() net.Addr { return testAddr("remote-addr") }
func (c *testConn) SetDeadline(t time.Time) error { return nil }
func (c *testConn) SetReadDeadline(t time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *testConn) SetDeadline(_ time.Time) error { return nil }
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil }
// getString converts byte slice to a string without memory allocation.
var getString = utils.UnsafeString

View File

@ -16,70 +16,6 @@ import (
"github.com/valyala/fasthttp"
)
// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4
func Benchmark_RemoveNewLines(b *testing.B) {
withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n"
withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue "
expected := utils.SafeString(withoutNL)
var res string
b.Run("withoutNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLines(withoutNL)
}
utils.AssertEqual(b, expected, res)
})
b.Run("withNL", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = removeNewLines(withNL)
}
utils.AssertEqual(b, expected, res)
})
}
// go test -v -run=RemoveNewLines_Bytes -count=3
func Test_RemoveNewLines_Bytes(t *testing.T) {
app := New()
t.Run("Not Status OK", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
c.Status(201)
setETag(c, false)
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
})
t.Run("No Body", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
setETag(c, false)
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
})
t.Run("Has HeaderIfNoneMatch", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`)
setETag(c, false)
utils.AssertEqual(t, 304, c.Response().StatusCode())
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
utils.AssertEqual(t, "", string(c.Response().Body()))
})
t.Run("No HeaderIfNoneMatch", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.SendString("Hello, World!")
setETag(c, false)
utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag)))
})
}
// go test -v -run=Test_Utils_ -count=3
func Test_Utils_ETag(t *testing.T) {
app := New()

View File

@ -0,0 +1,76 @@
# encoding/json [![GoDoc](https://godoc.org/github.com/segmentio/encoding/json?status.svg)](https://godoc.org/github.com/segmentio/encoding/json)
Go package offering a replacement implementation of the standard library's
[`encoding/json`](https://golang.org/pkg/encoding/json/) package, with much
better performance.
## Usage
The exported API of this package mirrors the standard library's
[`encoding/json`](https://golang.org/pkg/encoding/json/) package, the only
change needed to take advantage of the performance improvements is the import
path of the `json` package, from:
```go
import (
"encoding/json"
)
```
to
```go
import (
"github.com/segmentio/encoding/json"
)
```
One way to gain higher encoding throughput is to disable HTML escaping.
It allows the string encoding to use a much more efficient code path which
does not require parsing UTF-8 runes most of the time.
## Performance Improvements
The internal implementation uses a fair amount of unsafe operations (untyped
code, pointer arithmetic, etc...) to avoid using reflection as much as possible,
which is often the reason why serialization code has a large CPU and memory
footprint.
The package aims for zero unnecessary dynamic memory allocations and hot code
paths that are mostly free from calls into the reflect package.
## Compatibility with encoding/json
This package aims to be a drop-in replacement, therefore it is tested to behave
exactly like the standard library's package. However, there are still a few
missing features that have not been ported yet:
- Streaming decoder, currently the `Decoder` implementation offered by the
package does not support progressively reading values from a JSON array (unlike
the standard library). In our experience this is a very rare use-case, if you
need it you're better off sticking to the standard library, or spend a bit of
time implementing it in here ;)
Note that none of those features should result in performance degradations if
they were implemented in the package, and we welcome contributions!
## Trade-offs
As one would expect, we had to make a couple of trade-offs to achieve greater
performance than the standard library, but there were also features that we
did not want to give away.
Other open-source packages offering a reduced CPU and memory footprint usually
do so by designing a different API, or require code generation (therefore adding
complexity to the build process). These were not acceptable conditions for us,
as we were not willing to trade off developer productivity for better runtime
performance. To achieve this, we chose to exactly replicate the standard
library interfaces and behavior, which meant the package implementation was the
only area that we were able to work with. The internals of this package make
heavy use of unsafe pointer arithmetics and other performance optimizations,
and therefore are not as approachable as typical Go programs. Basically, we put
a bigger burden on maintainers to achieve better runtime cost without
sacrificing developer productivity.
For these reasons, we also don't believe that this code should be ported upstream
to the standard `encoding/json` package. The standard library has to remain
readable and approachable to maximize stability and maintainability, and make
projects like this one possible because a high quality reference implementation
already exists.

View File

@ -340,6 +340,24 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
encode: encoder.encodeMapStringRawMessage,
decode: decoder.decodeMapStringRawMessage,
}
case k == stringType && v == stringType:
return codec{
encode: encoder.encodeMapStringString,
decode: decoder.decodeMapStringString,
}
case k == stringType && v == stringsType:
return codec{
encode: encoder.encodeMapStringStringSlice,
decode: decoder.decodeMapStringStringSlice,
}
case k == stringType && v == boolType:
return codec{
encode: encoder.encodeMapStringBool,
decode: decoder.decodeMapStringBool,
}
}
kc := codec{}
@ -1035,6 +1053,7 @@ var (
numberType = reflect.TypeOf(json.Number(""))
stringType = reflect.TypeOf("")
stringsType = reflect.TypeOf([]string(nil))
bytesType = reflect.TypeOf(([]byte)(nil))
durationType = reflect.TypeOf(time.Duration(0))
timeType = reflect.TypeOf(time.Time{})
@ -1045,9 +1064,13 @@ var (
timePtrType = reflect.PtrTo(timeType)
rawMessagePtrType = reflect.PtrTo(rawMessageType)
sliceInterfaceType = reflect.TypeOf(([]interface{})(nil))
mapStringInterfaceType = reflect.TypeOf((map[string]interface{})(nil))
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
sliceInterfaceType = reflect.TypeOf(([]interface{})(nil))
sliceStringType = reflect.TypeOf(([]interface{})(nil))
mapStringInterfaceType = reflect.TypeOf((map[string]interface{})(nil))
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
mapStringStringType = reflect.TypeOf((map[string]string)(nil))
mapStringStringSliceType = reflect.TypeOf((map[string][]string)(nil))
mapStringBoolType = reflect.TypeOf((map[string]bool)(nil))
interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()

View File

@ -535,7 +535,7 @@ func (d decoder) decodeArray(b []byte, p unsafe.Pointer, n int, size uintptr, t
if err != nil {
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = t.String() + e.Struct
e.Field = strconv.Itoa(i) + "." + e.Field
e.Field = d.prependField(strconv.Itoa(i), e.Field)
}
return b, err
}
@ -637,7 +637,7 @@ func (d decoder) decodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = t.String() + e.Struct
e.Field = strconv.Itoa(s.len) + "." + e.Field
e.Field = d.prependField(strconv.Itoa(s.len), e.Field)
}
return b, err
}
@ -716,7 +716,7 @@ func (d decoder) decodeMap(b []byte, p unsafe.Pointer, t, kt, vt reflect.Type, k
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = "map[" + kt.String() + "]" + vt.String() + "{" + e.Struct + "}"
e.Field = fmt.Sprint(k.Interface()) + "." + e.Field
e.Field = d.prependField(fmt.Sprint(k.Interface()), e.Field)
}
return b, err
}
@ -797,7 +797,7 @@ func (d decoder) decodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = mapStringInterfaceType.String() + e.Struct
e.Field = key + "." + e.Field
e.Field = d.prependField(key, e.Field)
}
return b, err
}
@ -878,7 +878,254 @@ func (d decoder) decodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = mapStringRawMessageType.String() + e.Struct
e.Field = key + "." + e.Field
e.Field = d.prependField(key, e.Field)
}
return b, err
}
m[key] = val
i++
}
}
func (d decoder) decodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) {
if hasNullPrefix(b) {
*(*unsafe.Pointer)(p) = nil
return b[4:], nil
}
if len(b) < 2 || b[0] != '{' {
return inputError(b, mapStringStringType)
}
i := 0
m := *(*map[string]string)(p)
if m == nil {
m = make(map[string]string, 64)
}
var err error
var key string
var val string
var input = b
b = b[1:]
for {
key = ""
val = ""
b = skipSpaces(b)
if len(b) != 0 && b[0] == '}' {
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
return b[1:], nil
}
if i != 0 {
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field value")
}
if b[0] != ',' {
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
}
b = skipSpaces(b[1:])
}
if hasPrefix(b, "null") {
return b, syntaxError(b, "cannot decode object key string from 'null' value")
}
b, err = d.decodeString(b, unsafe.Pointer(&key))
if err != nil {
return objectKeyError(b, err)
}
b = skipSpaces(b)
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field key")
}
if b[0] != ':' {
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
}
b = skipSpaces(b[1:])
b, err = d.decodeString(b, unsafe.Pointer(&val))
if err != nil {
if _, r, err := parseValue(input); err != nil {
return r, err
} else {
b = r
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = mapStringStringType.String() + e.Struct
e.Field = d.prependField(key, e.Field)
}
return b, err
}
m[key] = val
i++
}
}
func (d decoder) decodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) {
if hasNullPrefix(b) {
*(*unsafe.Pointer)(p) = nil
return b[4:], nil
}
if len(b) < 2 || b[0] != '{' {
return inputError(b, mapStringStringSliceType)
}
i := 0
m := *(*map[string][]string)(p)
if m == nil {
m = make(map[string][]string, 64)
}
var err error
var key string
var buf []string
var input = b
var stringSize = unsafe.Sizeof("")
b = b[1:]
for {
key = ""
buf = buf[:0]
b = skipSpaces(b)
if len(b) != 0 && b[0] == '}' {
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
return b[1:], nil
}
if i != 0 {
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field value")
}
if b[0] != ',' {
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
}
b = skipSpaces(b[1:])
}
if hasPrefix(b, "null") {
return b, syntaxError(b, "cannot decode object key string from 'null' value")
}
b, err = d.decodeString(b, unsafe.Pointer(&key))
if err != nil {
return objectKeyError(b, err)
}
b = skipSpaces(b)
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field key")
}
if b[0] != ':' {
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
}
b = skipSpaces(b[1:])
b, err = d.decodeSlice(b, unsafe.Pointer(&buf), stringSize, sliceStringType, decoder.decodeString)
if err != nil {
if _, r, err := parseValue(input); err != nil {
return r, err
} else {
b = r
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = mapStringStringType.String() + e.Struct
e.Field = d.prependField(key, e.Field)
}
return b, err
}
val := make([]string, len(buf))
copy(val, buf)
m[key] = val
i++
}
}
func (d decoder) decodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) {
if hasNullPrefix(b) {
*(*unsafe.Pointer)(p) = nil
return b[4:], nil
}
if len(b) < 2 || b[0] != '{' {
return inputError(b, mapStringBoolType)
}
i := 0
m := *(*map[string]bool)(p)
if m == nil {
m = make(map[string]bool, 64)
}
var err error
var key string
var val bool
var input = b
b = b[1:]
for {
key = ""
val = false
b = skipSpaces(b)
if len(b) != 0 && b[0] == '}' {
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
return b[1:], nil
}
if i != 0 {
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field value")
}
if b[0] != ',' {
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
}
b = skipSpaces(b[1:])
}
if hasPrefix(b, "null") {
return b, syntaxError(b, "cannot decode object key string from 'null' value")
}
b, err = d.decodeString(b, unsafe.Pointer(&key))
if err != nil {
return objectKeyError(b, err)
}
b = skipSpaces(b)
if len(b) == 0 {
return b, syntaxError(b, "unexpected end of JSON input after object field key")
}
if b[0] != ':' {
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
}
b = skipSpaces(b[1:])
b, err = d.decodeBool(b, unsafe.Pointer(&val))
if err != nil {
if _, r, err := parseValue(input); err != nil {
return r, err
} else {
b = r
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = mapStringStringType.String() + e.Struct
e.Field = d.prependField(key, e.Field)
}
return b, err
}
@ -968,7 +1215,7 @@ func (d decoder) decodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byt
}
if e, ok := err.(*UnmarshalTypeError); ok {
e.Struct = st.typ.String() + e.Struct
e.Field = string(k) + "." + e.Field
e.Field = d.prependField(string(k), e.Field)
}
return b, err
}
@ -1190,3 +1437,10 @@ func (d decoder) decodeTextUnmarshaler(b []byte, p unsafe.Pointer, t reflect.Typ
return b, &UnmarshalTypeError{Value: value, Type: reflect.PtrTo(t)}
}
func (d decoder) prependField(key, field string) string {
if field != "" {
return key + "." + field
}
return key
}

View File

@ -476,6 +476,7 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
b = append(b, ',')
}
// encodeString doesn't return errors so we ignore it here
b, _ = e.encodeString(b, unsafe.Pointer(&k))
b = append(b, ':')
@ -534,6 +535,224 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
return b, nil
}
func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]string)(p)
if m == nil {
return append(b, "null"...), nil
}
if (e.flags & SortMapKeys) == 0 {
// Optimized code path when the program does not need the map keys to be
// sorted.
b = append(b, '{')
if len(m) != 0 {
var i = 0
for k, v := range m {
if i != 0 {
b = append(b, ',')
}
// encodeString never returns an error so we ignore it here
b, _ = e.encodeString(b, unsafe.Pointer(&k))
b = append(b, ':')
b, _ = e.encodeString(b, unsafe.Pointer(&v))
i++
}
}
b = append(b, '}')
return b, nil
}
s := mapslicePool.Get().(*mapslice)
if cap(s.elements) < len(m) {
s.elements = make([]element, 0, align(10, uintptr(len(m))))
}
for key, val := range m {
v := val
s.elements = append(s.elements, element{key: key, val: &v})
}
sort.Sort(s)
b = append(b, '{')
for i, elem := range s.elements {
if i != 0 {
b = append(b, ',')
}
// encodeString never returns an error so we ignore it here
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
b = append(b, ':')
b, _ = e.encodeString(b, unsafe.Pointer(elem.val.(*string)))
}
for i := range s.elements {
s.elements[i] = element{}
}
s.elements = s.elements[:0]
mapslicePool.Put(s)
b = append(b, '}')
return b, nil
}
func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string][]string)(p)
if m == nil {
return append(b, "null"...), nil
}
var stringSize = unsafe.Sizeof("")
if (e.flags & SortMapKeys) == 0 {
// Optimized code path when the program does not need the map keys to be
// sorted.
b = append(b, '{')
if len(m) != 0 {
var err error
var i = 0
for k, v := range m {
if i != 0 {
b = append(b, ',')
}
b, _ = e.encodeString(b, unsafe.Pointer(&k))
b = append(b, ':')
b, err = e.encodeSlice(b, unsafe.Pointer(&v), stringSize, sliceStringType, encoder.encodeString)
if err != nil {
return b, err
}
i++
}
}
b = append(b, '}')
return b, nil
}
s := mapslicePool.Get().(*mapslice)
if cap(s.elements) < len(m) {
s.elements = make([]element, 0, align(10, uintptr(len(m))))
}
for key, val := range m {
v := val
s.elements = append(s.elements, element{key: key, val: &v})
}
sort.Sort(s)
var start = len(b)
var err error
b = append(b, '{')
for i, elem := range s.elements {
if i != 0 {
b = append(b, ',')
}
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
b = append(b, ':')
b, err = e.encodeSlice(b, unsafe.Pointer(elem.val.(*[]string)), stringSize, sliceStringType, encoder.encodeString)
if err != nil {
break
}
}
for i := range s.elements {
s.elements[i] = element{}
}
s.elements = s.elements[:0]
mapslicePool.Put(s)
if err != nil {
return b[:start], err
}
b = append(b, '}')
return b, nil
}
func (e encoder) encodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]bool)(p)
if m == nil {
return append(b, "null"...), nil
}
if (e.flags & SortMapKeys) == 0 {
// Optimized code path when the program does not need the map keys to be
// sorted.
b = append(b, '{')
if len(m) != 0 {
var i = 0
for k, v := range m {
if i != 0 {
b = append(b, ',')
}
// encodeString never returns an error so we ignore it here
b, _ = e.encodeString(b, unsafe.Pointer(&k))
if v {
b = append(b, ":true"...)
} else {
b = append(b, ":false"...)
}
i++
}
}
b = append(b, '}')
return b, nil
}
s := mapslicePool.Get().(*mapslice)
if cap(s.elements) < len(m) {
s.elements = make([]element, 0, align(10, uintptr(len(m))))
}
for key, val := range m {
s.elements = append(s.elements, element{key: key, val: val})
}
sort.Sort(s)
b = append(b, '{')
for i, elem := range s.elements {
if i != 0 {
b = append(b, ',')
}
// encodeString never returns an error so we ignore it here
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
if elem.val.(bool) {
b = append(b, ":true"...)
} else {
b = append(b, ":false"...)
}
}
for i := range s.elements {
s.elements[i] = element{}
}
s.elements = s.elements[:0]
mapslicePool.Put(s)
b = append(b, '}')
return b, nil
}
func (e encoder) encodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byte, error) {
var start = len(b)
var err error

View File

@ -1,58 +0,0 @@
package json
import (
"math"
"strconv"
"testing"
)
func TestAppendInt(t *testing.T) {
var ints []int64
for i := 0; i < 64; i++ {
u := uint64(1) << i
ints = append(ints, int64(u-1), int64(u), int64(u+1), -int64(u))
}
var std [20]byte
var our [20]byte
for _, i := range ints {
expected := strconv.AppendInt(std[:], i, 10)
actual := appendInt(our[:], i)
if string(expected) != string(actual) {
t.Fatalf("appendInt(%d) = %v, expected = %v", i, string(actual), string(expected))
}
}
}
func benchStd(b *testing.B, n int64) {
var buf [20]byte
b.ResetTimer()
for i := 0; i < b.N; i++ {
strconv.AppendInt(buf[:0], n, 10)
}
}
func benchNew(b *testing.B, n int64) {
var buf [20]byte
b.ResetTimer()
for i := 0; i < b.N; i++ {
appendInt(buf[:0], n)
}
}
func BenchmarkAppendIntStd1(b *testing.B) {
benchStd(b, 1)
}
func BenchmarkAppendInt1(b *testing.B) {
benchNew(b, 1)
}
func BenchmarkAppendIntStdMinI64(b *testing.B) {
benchStd(b, math.MinInt64)
}
func BenchmarkAppendIntMinI64(b *testing.B) {
benchNew(b, math.MinInt64)
}

20
internal/gotiny/LICENSE Normal file
View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2016 zheng-ji.info
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,205 @@
package gotiny
import (
"fmt"
"reflect"
"sync"
"time"
"unsafe"
)
type decEng func(*Decoder, unsafe.Pointer) // 解码器
var (
rt2decEng = map[reflect.Type]decEng{
reflect.TypeOf((*bool)(nil)).Elem(): decBool,
reflect.TypeOf((*int)(nil)).Elem(): decInt,
reflect.TypeOf((*int8)(nil)).Elem(): decInt8,
reflect.TypeOf((*int16)(nil)).Elem(): decInt16,
reflect.TypeOf((*int32)(nil)).Elem(): decInt32,
reflect.TypeOf((*int64)(nil)).Elem(): decInt64,
reflect.TypeOf((*uint)(nil)).Elem(): decUint,
reflect.TypeOf((*uint8)(nil)).Elem(): decUint8,
reflect.TypeOf((*uint16)(nil)).Elem(): decUint16,
reflect.TypeOf((*uint32)(nil)).Elem(): decUint32,
reflect.TypeOf((*uint64)(nil)).Elem(): decUint64,
reflect.TypeOf((*uintptr)(nil)).Elem(): decUintptr,
reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): decPointer,
reflect.TypeOf((*float32)(nil)).Elem(): decFloat32,
reflect.TypeOf((*float64)(nil)).Elem(): decFloat64,
reflect.TypeOf((*complex64)(nil)).Elem(): decComplex64,
reflect.TypeOf((*complex128)(nil)).Elem(): decComplex128,
reflect.TypeOf((*[]byte)(nil)).Elem(): decBytes,
reflect.TypeOf((*string)(nil)).Elem(): decString,
reflect.TypeOf((*time.Time)(nil)).Elem(): decTime,
reflect.TypeOf((*struct{})(nil)).Elem(): decIgnore,
reflect.TypeOf(nil): decIgnore,
}
baseDecEngines = []decEng{
reflect.Invalid: decIgnore,
reflect.Bool: decBool,
reflect.Int: decInt,
reflect.Int8: decInt8,
reflect.Int16: decInt16,
reflect.Int32: decInt32,
reflect.Int64: decInt64,
reflect.Uint: decUint,
reflect.Uint8: decUint8,
reflect.Uint16: decUint16,
reflect.Uint32: decUint32,
reflect.Uint64: decUint64,
reflect.Uintptr: decUintptr,
reflect.UnsafePointer: decPointer,
reflect.Float32: decFloat32,
reflect.Float64: decFloat64,
reflect.Complex64: decComplex64,
reflect.Complex128: decComplex128,
reflect.String: decString,
}
decLock sync.RWMutex
)
func getDecEngine(rt reflect.Type) decEng {
decLock.RLock()
engine := rt2decEng[rt]
decLock.RUnlock()
if engine != nil {
return engine
}
decLock.Lock()
buildDecEngine(rt, &engine)
decLock.Unlock()
return engine
}
func buildDecEngine(rt reflect.Type, engPtr *decEng) {
engine, has := rt2decEng[rt]
if has {
*engPtr = engine
return
}
if _, engine = implementOtherSerializer(rt); engine != nil {
rt2decEng[rt] = engine
*engPtr = engine
return
}
kind := rt.Kind()
var eEng decEng
switch kind {
case reflect.Ptr:
et := rt.Elem()
defer buildDecEngine(et, &eEng)
engine = func(d *Decoder, p unsafe.Pointer) {
if d.decIsNotNil() {
if isNil(p) {
*(*unsafe.Pointer)(p) = unsafe.Pointer(reflect.New(et).Elem().UnsafeAddr())
}
eEng(d, *(*unsafe.Pointer)(p))
} else if !isNil(p) {
*(*unsafe.Pointer)(p) = nil
}
}
case reflect.Array:
l, et := rt.Len(), rt.Elem()
size := et.Size()
defer buildDecEngine(et, &eEng)
engine = func(d *Decoder, p unsafe.Pointer) {
for i := 0; i < l; i++ {
eEng(d, unsafe.Pointer(uintptr(p)+uintptr(i)*size))
}
}
case reflect.Slice:
et := rt.Elem()
size := et.Size()
defer buildDecEngine(et, &eEng)
engine = func(d *Decoder, p unsafe.Pointer) {
header := (*reflect.SliceHeader)(p)
if d.decIsNotNil() {
l := d.decLength()
if isNil(p) || header.Cap < l {
*header = reflect.SliceHeader{Data: reflect.MakeSlice(rt, l, l).Pointer(), Len: l, Cap: l}
} else {
header.Len = l
}
for i := 0; i < l; i++ {
eEng(d, unsafe.Pointer(header.Data+uintptr(i)*size))
}
} else if !isNil(p) {
*header = reflect.SliceHeader{}
}
}
case reflect.Map:
kt, vt := rt.Key(), rt.Elem()
skt, svt := reflect.SliceOf(kt), reflect.SliceOf(vt)
var kEng, vEng decEng
defer buildDecEngine(kt, &kEng)
defer buildDecEngine(vt, &vEng)
engine = func(d *Decoder, p unsafe.Pointer) {
if d.decIsNotNil() {
l := d.decLength()
var v reflect.Value
if isNil(p) {
v = reflect.MakeMapWithSize(rt, l)
*(*unsafe.Pointer)(p) = unsafe.Pointer(v.Pointer())
} else {
v = reflect.NewAt(rt, p).Elem()
}
keys, vals := reflect.MakeSlice(skt, l, l), reflect.MakeSlice(svt, l, l)
for i := 0; i < l; i++ {
key, val := keys.Index(i), vals.Index(i)
kEng(d, unsafe.Pointer(key.UnsafeAddr()))
vEng(d, unsafe.Pointer(val.UnsafeAddr()))
v.SetMapIndex(key, val)
}
} else if !isNil(p) {
*(*unsafe.Pointer)(p) = nil
}
}
case reflect.Struct:
fields, offs := getFieldType(rt, 0)
nf := len(fields)
fEngines := make([]decEng, nf)
defer func() {
for i := 0; i < nf; i++ {
buildDecEngine(fields[i], &fEngines[i])
}
}()
engine = func(d *Decoder, p unsafe.Pointer) {
for i := 0; i < len(fEngines) && i < len(offs); i++ {
fEngines[i](d, unsafe.Pointer(uintptr(p)+offs[i]))
}
}
case reflect.Interface:
engine = func(d *Decoder, p unsafe.Pointer) {
if d.decIsNotNil() {
name := ""
decString(d, unsafe.Pointer(&name))
et, has := name2type[name]
if !has {
//panic("unknown typ:" + name)
fmt.Println("[session] Register this type first with the `RegisterType` method.")
}
v := reflect.NewAt(rt, p).Elem()
var ev reflect.Value
if v.IsNil() || v.Elem().Type() != et {
ev = reflect.New(et).Elem()
} else {
ev = v.Elem()
}
getDecEngine(et)(d, getUnsafePointer(&ev))
v.Set(ev)
} else if !isNil(p) {
*(*unsafe.Pointer)(p) = nil
}
}
case reflect.Chan, reflect.Func:
//panic("not support " + rt.String() + " type")
default:
engine = baseDecEngines[kind]
}
rt2decEng[rt] = engine
*engPtr = engine
}

161
internal/gotiny/decbase.go Normal file
View File

@ -0,0 +1,161 @@
package gotiny
import (
"time"
"unsafe"
)
func (d *Decoder) decBool() (b bool) {
if d.boolBit == 0 {
d.boolBit = 1
d.boolPos = d.buf[d.index]
d.index++
}
b = d.boolPos&d.boolBit != 0
d.boolBit <<= 1
return
}
func (d *Decoder) decUint64() uint64 {
buf, i := d.buf, d.index
x := uint64(buf[i])
if x < 0x80 {
d.index++
return x
}
x1 := buf[i+1]
x += uint64(x1) << 7
if x1 < 0x80 {
d.index += 2
return x - 1<<7
}
x2 := buf[i+2]
x += uint64(x2) << 14
if x2 < 0x80 {
d.index += 3
return x - (1<<7 + 1<<14)
}
x3 := buf[i+3]
x += uint64(x3) << 21
if x3 < 0x80 {
d.index += 4
return x - (1<<7 + 1<<14 + 1<<21)
}
x4 := buf[i+4]
x += uint64(x4) << 28
if x4 < 0x80 {
d.index += 5
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28)
}
x5 := buf[i+5]
x += uint64(x5) << 35
if x5 < 0x80 {
d.index += 6
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35)
}
x6 := buf[i+6]
x += uint64(x6) << 42
if x6 < 0x80 {
d.index += 7
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42)
}
x7 := buf[i+7]
x += uint64(x7) << 49
if x7 < 0x80 {
d.index += 8
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49)
}
d.index += 9
return x + uint64(buf[i+8])<<56 - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49 + 1<<56)
}
func (d *Decoder) decUint16() uint16 {
buf, i := d.buf, d.index
x := uint16(buf[i])
if x < 0x80 {
d.index++
return x
}
x1 := buf[i+1]
x += uint16(x1) << 7
if x1 < 0x80 {
d.index += 2
return x - 1<<7
}
d.index += 3
return x + uint16(buf[i+2])<<14 - (1<<7 + 1<<14)
}
func (d *Decoder) decUint32() uint32 {
buf, i := d.buf, d.index
x := uint32(buf[i])
if x < 0x80 {
d.index++
return x
}
x1 := buf[i+1]
x += uint32(x1) << 7
if x1 < 0x80 {
d.index += 2
return x - 1<<7
}
x2 := buf[i+2]
x += uint32(x2) << 14
if x2 < 0x80 {
d.index += 3
return x - (1<<7 + 1<<14)
}
x3 := buf[i+3]
x += uint32(x3) << 21
if x3 < 0x80 {
d.index += 4
return x - (1<<7 + 1<<14 + 1<<21)
}
x4 := buf[i+4]
x += uint32(x4) << 28
d.index += 5
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28)
}
func (d *Decoder) decLength() int { return int(d.decUint32()) }
func (d *Decoder) decIsNotNil() bool { return d.decBool() }
func decIgnore(*Decoder, unsafe.Pointer) {}
func decBool(d *Decoder, p unsafe.Pointer) { *(*bool)(p) = d.decBool() }
func decInt(d *Decoder, p unsafe.Pointer) { *(*int)(p) = int(uint64ToInt64(d.decUint64())) }
func decInt8(d *Decoder, p unsafe.Pointer) { *(*int8)(p) = int8(d.buf[d.index]); d.index++ }
func decInt16(d *Decoder, p unsafe.Pointer) { *(*int16)(p) = uint16ToInt16(d.decUint16()) }
func decInt32(d *Decoder, p unsafe.Pointer) { *(*int32)(p) = uint32ToInt32(d.decUint32()) }
func decInt64(d *Decoder, p unsafe.Pointer) { *(*int64)(p) = uint64ToInt64(d.decUint64()) }
func decUint(d *Decoder, p unsafe.Pointer) { *(*uint)(p) = uint(d.decUint64()) }
func decUint8(d *Decoder, p unsafe.Pointer) { *(*uint8)(p) = d.buf[d.index]; d.index++ }
func decUint16(d *Decoder, p unsafe.Pointer) { *(*uint16)(p) = d.decUint16() }
func decUint32(d *Decoder, p unsafe.Pointer) { *(*uint32)(p) = d.decUint32() }
func decUint64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() }
func decUintptr(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) }
func decPointer(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) }
func decFloat32(d *Decoder, p unsafe.Pointer) { *(*float32)(p) = uint32ToFloat32(d.decUint32()) }
func decFloat64(d *Decoder, p unsafe.Pointer) { *(*float64)(p) = uint64ToFloat64(d.decUint64()) }
func decTime(d *Decoder, p unsafe.Pointer) { *(*time.Time)(p) = time.Unix(0, int64(d.decUint64())) }
func decComplex64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() }
func decComplex128(d *Decoder, p unsafe.Pointer) {
*(*uint64)(p) = d.decUint64()
*(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size)) = d.decUint64()
}
func decString(d *Decoder, p unsafe.Pointer) {
l, val := int(d.decUint32()), (*string)(p)
*val = string(d.buf[d.index : d.index+l])
d.index += l
}
func decBytes(d *Decoder, p unsafe.Pointer) {
bytes := (*[]byte)(p)
if d.decIsNotNil() {
l := int(d.decUint32())
*bytes = d.buf[d.index : d.index+l]
d.index += l
} else if !isNil(p) {
*bytes = nil
}
}

View File

@ -0,0 +1,97 @@
package gotiny
import (
"reflect"
"unsafe"
)
type Decoder struct {
buf []byte //buf
index int //下一个要使用的字节在buf中的下标
boolPos byte //下一次要读取的bool在buf中的下标,即buf[boolPos]
boolBit byte //下一次要读取的bool的buf[boolPos]中的bit位
engines []decEng //解码器集合
length int //解码器数量
}
func Unmarshal(buf []byte, is ...interface{}) int {
return NewDecoderWithPtr(is...).Decode(buf, is...)
}
func NewDecoderWithPtr(is ...interface{}) *Decoder {
l := len(is)
engines := make([]decEng, l)
for i := 0; i < l; i++ {
rt := reflect.TypeOf(is[i])
if rt.Kind() != reflect.Ptr {
panic("must a pointer type!")
}
engines[i] = getDecEngine(rt.Elem())
}
return &Decoder{
length: l,
engines: engines,
}
}
func NewDecoder(is ...interface{}) *Decoder {
l := len(is)
engines := make([]decEng, l)
for i := 0; i < l; i++ {
engines[i] = getDecEngine(reflect.TypeOf(is[i]))
}
return &Decoder{
length: l,
engines: engines,
}
}
func NewDecoderWithType(ts ...reflect.Type) *Decoder {
l := len(ts)
des := make([]decEng, l)
for i := 0; i < l; i++ {
des[i] = getDecEngine(ts[i])
}
return &Decoder{
length: l,
engines: des,
}
}
func (d *Decoder) reset() int {
index := d.index
d.index = 0
d.boolPos = 0
d.boolBit = 0
return index
}
// is is pointer of variable
func (d *Decoder) Decode(buf []byte, is ...interface{}) int {
d.buf = buf
engines := d.engines
for i := 0; i < len(engines) && i < len(is); i++ {
engines[i](d, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1])
}
return d.reset()
}
// ps is a unsafe.Pointer of the variable
func (d *Decoder) DecodePtr(buf []byte, ps ...unsafe.Pointer) int {
d.buf = buf
engines := d.engines
for i := 0; i < len(engines) && i < len(ps); i++ {
engines[i](d, ps[i])
}
return d.reset()
}
func (d *Decoder) DecodeValue(buf []byte, vs ...reflect.Value) int {
d.buf = buf
engines := d.engines
for i := 0; i < len(engines) && i < len(vs); i++ {
engines[i](d, unsafe.Pointer(vs[i].UnsafeAddr()))
}
return d.reset()
}

View File

@ -0,0 +1,196 @@
package gotiny
import (
"reflect"
"sync"
"time"
"unsafe"
)
type encEng func(*Encoder, unsafe.Pointer) //编码器
var (
rt2encEng = map[reflect.Type]encEng{
reflect.TypeOf((*bool)(nil)).Elem(): encBool,
reflect.TypeOf((*int)(nil)).Elem(): encInt,
reflect.TypeOf((*int8)(nil)).Elem(): encInt8,
reflect.TypeOf((*int16)(nil)).Elem(): encInt16,
reflect.TypeOf((*int32)(nil)).Elem(): encInt32,
reflect.TypeOf((*int64)(nil)).Elem(): encInt64,
reflect.TypeOf((*uint)(nil)).Elem(): encUint,
reflect.TypeOf((*uint8)(nil)).Elem(): encUint8,
reflect.TypeOf((*uint16)(nil)).Elem(): encUint16,
reflect.TypeOf((*uint32)(nil)).Elem(): encUint32,
reflect.TypeOf((*uint64)(nil)).Elem(): encUint64,
reflect.TypeOf((*uintptr)(nil)).Elem(): encUintptr,
reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): encPointer,
reflect.TypeOf((*float32)(nil)).Elem(): encFloat32,
reflect.TypeOf((*float64)(nil)).Elem(): encFloat64,
reflect.TypeOf((*complex64)(nil)).Elem(): encComplex64,
reflect.TypeOf((*complex128)(nil)).Elem(): encComplex128,
reflect.TypeOf((*[]byte)(nil)).Elem(): encBytes,
reflect.TypeOf((*string)(nil)).Elem(): encString,
reflect.TypeOf((*time.Time)(nil)).Elem(): encTime,
reflect.TypeOf((*struct{})(nil)).Elem(): encIgnore,
reflect.TypeOf(nil): encIgnore,
}
encEngines = [...]encEng{
reflect.Invalid: encIgnore,
reflect.Bool: encBool,
reflect.Int: encInt,
reflect.Int8: encInt8,
reflect.Int16: encInt16,
reflect.Int32: encInt32,
reflect.Int64: encInt64,
reflect.Uint: encUint,
reflect.Uint8: encUint8,
reflect.Uint16: encUint16,
reflect.Uint32: encUint32,
reflect.Uint64: encUint64,
reflect.Uintptr: encUintptr,
reflect.UnsafePointer: encPointer,
reflect.Float32: encFloat32,
reflect.Float64: encFloat64,
reflect.Complex64: encComplex64,
reflect.Complex128: encComplex128,
reflect.String: encString,
}
encLock sync.RWMutex
)
func UnusedUnixNanoEncodeTimeType() {
delete(rt2encEng, reflect.TypeOf((*time.Time)(nil)).Elem())
delete(rt2decEng, reflect.TypeOf((*time.Time)(nil)).Elem())
}
func getEncEngine(rt reflect.Type) encEng {
encLock.RLock()
engine := rt2encEng[rt]
encLock.RUnlock()
if engine != nil {
return engine
}
encLock.Lock()
buildEncEngine(rt, &engine)
encLock.Unlock()
return engine
}
func buildEncEngine(rt reflect.Type, engPtr *encEng) {
engine := rt2encEng[rt]
if engine != nil {
*engPtr = engine
return
}
if engine, _ = implementOtherSerializer(rt); engine != nil {
rt2encEng[rt] = engine
*engPtr = engine
return
}
kind := rt.Kind()
var eEng encEng
switch kind {
case reflect.Ptr:
defer buildEncEngine(rt.Elem(), &eEng)
engine = func(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
eEng(e, *(*unsafe.Pointer)(p))
}
}
case reflect.Array:
et, l := rt.Elem(), rt.Len()
defer buildEncEngine(et, &eEng)
size := et.Size()
engine = func(e *Encoder, p unsafe.Pointer) {
for i := 0; i < l; i++ {
eEng(e, unsafe.Pointer(uintptr(p)+uintptr(i)*size))
}
}
case reflect.Slice:
et := rt.Elem()
size := et.Size()
defer buildEncEngine(et, &eEng)
engine = func(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
header := (*reflect.SliceHeader)(p)
l := header.Len
e.encLength(l)
for i := 0; i < l; i++ {
eEng(e, unsafe.Pointer(header.Data+uintptr(i)*size))
}
}
}
case reflect.Map:
var kEng encEng
defer buildEncEngine(rt.Key(), &kEng)
defer buildEncEngine(rt.Elem(), &eEng)
engine = func(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
v := reflect.NewAt(rt, p).Elem()
e.encLength(v.Len())
keys := v.MapKeys()
for i := 0; i < len(keys); i++ {
val := v.MapIndex(keys[i])
kEng(e, getUnsafePointer(&keys[i]))
eEng(e, getUnsafePointer(&val))
}
}
}
case reflect.Struct:
fields, offs := getFieldType(rt, 0)
nf := len(fields)
fEngines := make([]encEng, nf)
defer func() {
for i := 0; i < nf; i++ {
buildEncEngine(fields[i], &fEngines[i])
}
}()
engine = func(e *Encoder, p unsafe.Pointer) {
for i := 0; i < len(fEngines) && i < len(offs); i++ {
fEngines[i](e, unsafe.Pointer(uintptr(p)+offs[i]))
}
}
case reflect.Interface:
if rt.NumMethod() > 0 {
engine = func(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
v := reflect.ValueOf(*(*interface {
M()
})(p))
et := v.Type()
e.encString(getNameOfType(et))
getEncEngine(et)(e, getUnsafePointer(&v))
}
}
} else {
engine = func(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
v := reflect.ValueOf(*(*interface{})(p))
et := v.Type()
e.encString(getNameOfType(et))
getEncEngine(et)(e, getUnsafePointer(&v))
}
}
}
case reflect.Chan, reflect.Func:
//panic("not support " + rt.String() + " type")
default:
engine = encEngines[kind]
}
rt2encEng[rt] = engine
*engPtr = engine
}

108
internal/gotiny/encbase.go Normal file
View File

@ -0,0 +1,108 @@
package gotiny
import (
"time"
"unsafe"
)
func (e *Encoder) encBool(v bool) {
if e.boolBit == 0 {
e.boolPos = len(e.buf)
e.buf = append(e.buf, 0)
e.boolBit = 1
}
if v {
e.buf[e.boolPos] |= e.boolBit
}
e.boolBit <<= 1
}
func (e *Encoder) encUint64(v uint64) {
switch {
case v < 1<<7-1:
e.buf = append(e.buf, byte(v))
case v < 1<<14-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
case v < 1<<21-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
case v < 1<<28-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21))
case v < 1<<35-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28))
case v < 1<<42-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35))
case v < 1<<49-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42))
case v < 1<<56-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49))
default:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49)|0x80, byte(v>>56))
}
}
func (e *Encoder) encUint16(v uint16) {
if v < 1<<7-1 {
e.buf = append(e.buf, byte(v))
} else if v < 1<<14-1 {
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
} else {
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
}
}
func (e *Encoder) encUint32(v uint32) {
switch {
case v < 1<<7-1:
e.buf = append(e.buf, byte(v))
case v < 1<<14-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
case v < 1<<21-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
case v < 1<<28-1:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21))
default:
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28))
}
}
func (e *Encoder) encLength(v int) { e.encUint32(uint32(v)) }
func (e *Encoder) encString(s string) { e.encUint32(uint32(len(s))); e.buf = append(e.buf, s...) }
func (e *Encoder) encIsNotNil(v bool) { e.encBool(v) }
func encIgnore(*Encoder, unsafe.Pointer) {}
func encBool(e *Encoder, p unsafe.Pointer) { e.encBool(*(*bool)(p)) }
func encInt(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(int64(*(*int)(p)))) }
func encInt8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) }
func encInt16(e *Encoder, p unsafe.Pointer) { e.encUint16(int16ToUint16(*(*int16)(p))) }
func encInt32(e *Encoder, p unsafe.Pointer) { e.encUint32(int32ToUint32(*(*int32)(p))) }
func encInt64(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(*(*int64)(p))) }
func encUint8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) }
func encUint16(e *Encoder, p unsafe.Pointer) { e.encUint16(*(*uint16)(p)) }
func encUint32(e *Encoder, p unsafe.Pointer) { e.encUint32(*(*uint32)(p)) }
func encUint64(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint64)(p))) }
func encUint(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint)(p))) }
func encUintptr(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) }
func encPointer(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) }
func encFloat32(e *Encoder, p unsafe.Pointer) { e.encUint32(float32ToUint32(p)) }
func encFloat64(e *Encoder, p unsafe.Pointer) { e.encUint64(float64ToUint64(p)) }
func encString(e *Encoder, p unsafe.Pointer) {
s := *(*string)(p)
e.encUint32(uint32(len(s)))
e.buf = append(e.buf, s...)
}
func encTime(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64((*time.Time)(p).UnixNano())) }
func encComplex64(e *Encoder, p unsafe.Pointer) { e.encUint64(*(*uint64)(p)) }
func encComplex128(e *Encoder, p unsafe.Pointer) {
e.encUint64(*(*uint64)(p))
e.encUint64(*(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size)))
}
func encBytes(e *Encoder, p unsafe.Pointer) {
isNotNil := !isNil(p)
e.encIsNotNil(isNotNil)
if isNotNil {
buf := *(*[]byte)(p)
e.encLength(len(buf))
e.buf = append(e.buf, buf...)
}
}

103
internal/gotiny/encoder.go Normal file
View File

@ -0,0 +1,103 @@
package gotiny
import (
"reflect"
"unsafe"
)
type Encoder struct {
buf []byte //编码目的数组
off int
boolPos int //下一次要设置的bool在buf中的下标,即buf[boolPos]
boolBit byte //下一次要设置的bool的buf[boolPos]中的bit位
engines []encEng
length int
}
func Marshal(is ...interface{}) []byte {
return NewEncoderWithPtr(is...).Encode(is...)
}
// 创建一个编码ps 指向类型的编码器
func NewEncoderWithPtr(ps ...interface{}) *Encoder {
l := len(ps)
engines := make([]encEng, l)
for i := 0; i < l; i++ {
rt := reflect.TypeOf(ps[i])
if rt.Kind() != reflect.Ptr {
panic("must a pointer type!")
}
engines[i] = getEncEngine(rt.Elem())
}
return &Encoder{
length: l,
engines: engines,
}
}
// 创建一个编码is 类型的编码器
func NewEncoder(is ...interface{}) *Encoder {
l := len(is)
engines := make([]encEng, l)
for i := 0; i < l; i++ {
engines[i] = getEncEngine(reflect.TypeOf(is[i]))
}
return &Encoder{
length: l,
engines: engines,
}
}
func NewEncoderWithType(ts ...reflect.Type) *Encoder {
l := len(ts)
engines := make([]encEng, l)
for i := 0; i < l; i++ {
engines[i] = getEncEngine(ts[i])
}
return &Encoder{
length: l,
engines: engines,
}
}
// 入参是要编码值的指针
func (e *Encoder) Encode(is ...interface{}) []byte {
engines := e.engines
for i := 0; i < len(engines) && i < len(is); i++ {
engines[i](e, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1])
}
return e.reset()
}
// 入参是要编码的值得unsafe.Pointer 指针
func (e *Encoder) EncodePtr(ps ...unsafe.Pointer) []byte {
engines := e.engines
for i := 0; i < len(engines) && i < len(ps); i++ {
engines[i](e, ps[i])
}
return e.reset()
}
// vs 是持有要编码的值
func (e *Encoder) EncodeValue(vs ...reflect.Value) []byte {
engines := e.engines
for i := 0; i < len(engines) && i < len(vs); i++ {
engines[i](e, getUnsafePointer(&vs[i]))
}
return e.reset()
}
// 编码产生的数据将append到buf上
func (e *Encoder) AppendTo(buf []byte) {
e.off = len(buf)
e.buf = buf
}
func (e *Encoder) reset() []byte {
buf := e.buf
e.buf = buf[:e.off]
e.boolBit = 0
e.boolPos = 0
return buf
}

144
internal/gotiny/register.go Normal file
View File

@ -0,0 +1,144 @@
package gotiny
import (
"reflect"
"strconv"
)
var (
type2name = map[reflect.Type]string{}
name2type = map[string]reflect.Type{}
)
func GetName(obj interface{}) string {
return GetNameByType(reflect.TypeOf(obj))
}
func GetNameByType(rt reflect.Type) string {
return string(getName([]byte(nil), rt))
}
func getName(prefix []byte, rt reflect.Type) []byte {
if rt == nil || rt.Kind() == reflect.Invalid {
return append(prefix, []byte("<nil>")...)
}
if rt.Name() == "" { //未命名的,组合类型
switch rt.Kind() {
case reflect.Ptr:
return getName(append(prefix, '*'), rt.Elem())
case reflect.Array:
return getName(append(prefix, "["+strconv.Itoa(rt.Len())+"]"...), rt.Elem())
case reflect.Slice:
return getName(append(prefix, '[', ']'), rt.Elem())
case reflect.Struct:
prefix = append(prefix, "struct {"...)
nf := rt.NumField()
if nf > 0 {
prefix = append(prefix, ' ')
}
for i := 0; i < nf; i++ {
field := rt.Field(i)
if field.Anonymous {
prefix = getName(prefix, field.Type)
} else {
prefix = getName(append(prefix, field.Name+" "...), field.Type)
}
if i != nf-1 {
prefix = append(prefix, ';', ' ')
} else {
prefix = append(prefix, ' ')
}
}
return append(prefix, '}')
case reflect.Map:
return getName(append(getName(append(prefix, "map["...), rt.Key()), ']'), rt.Elem())
case reflect.Interface:
prefix = append(prefix, "interface {"...)
nm := rt.NumMethod()
if nm > 0 {
prefix = append(prefix, ' ')
}
for i := 0; i < nm; i++ {
method := rt.Method(i)
fn := getName([]byte(nil), method.Type)
prefix = append(prefix, method.Name+string(fn[4:])...)
if i != nm-1 {
prefix = append(prefix, ';', ' ')
} else {
prefix = append(prefix, ' ')
}
}
return append(prefix, '}')
case reflect.Func:
prefix = append(prefix, "func("...)
for i := 0; i < rt.NumIn(); i++ {
prefix = getName(prefix, rt.In(i))
if i != rt.NumIn()-1 {
prefix = append(prefix, ',', ' ')
}
}
prefix = append(prefix, ')')
no := rt.NumOut()
if no > 0 {
prefix = append(prefix, ' ')
}
if no > 1 {
prefix = append(prefix, '(')
}
for i := 0; i < no; i++ {
prefix = getName(prefix, rt.Out(i))
if i != no-1 {
prefix = append(prefix, ',', ' ')
}
}
if no > 1 {
prefix = append(prefix, ')')
}
return prefix
}
}
if rt.PkgPath() == "" {
prefix = append(prefix, rt.Name()...)
} else {
prefix = append(prefix, rt.PkgPath()+"."+rt.Name()...)
}
return prefix
}
func getNameOfType(rt reflect.Type) string {
if name, has := type2name[rt]; has {
return name
} else {
return registerType(rt)
}
}
func Register(i interface{}) string {
return registerType(reflect.TypeOf(i))
}
func registerType(rt reflect.Type) string {
name := GetNameByType(rt)
RegisterName(name, rt)
return name
}
func RegisterName(name string, rt reflect.Type) {
if name == "" {
panic("attempt to register empty name")
}
if rt == nil || rt.Kind() == reflect.Invalid {
panic("attempt to register nil type or invalid type")
}
if _, has := type2name[rt]; has {
panic("gotiny: registering duplicate types for " + GetNameByType(rt))
}
if _, has := name2type[name]; has {
panic("gotiny: registering name" + name + " is exist")
}
name2type[name] = rt
type2name[rt] = name
}

57
internal/gotiny/unsafe.go Normal file
View File

@ -0,0 +1,57 @@
package gotiny
import (
"reflect"
"unsafe"
)
const (
kindDirectIface = 1 << 5
)
// rtype is the common implementation of most values.
// It is embedded in other struct types.
//
// rtype must be kept in sync with reflect/type.go:/^type._type.
type rtype struct {
_ uintptr
_ uintptr // number of bytes in the type that can contain pointers
_ uint32 // hash of type; avoids computation in hash tables
_ uint8 // extra type information flags
_ uint8 // alignment of variable with this type
_ uint8 // alignment of struct field with this type
kind uint8 // enumeration for C
_ uintptr // algorithm table
_ uintptr // garbage collection data
_ int32 // string form
_ int32 // type for pointer to this type, may be zero
}
// ifaceIndir reports whether t is stored indirectly in an interface value.
func ifaceDirect(t *rtype) bool {
return t.kind&kindDirectIface != 0
}
func directType(rt *reflect.Type) bool {
return ifaceDirect((*rtype)((*[2]unsafe.Pointer)(unsafe.Pointer(rt))[1]))
}
type refVal struct {
_ unsafe.Pointer
ptr unsafe.Pointer
flag flag
}
type flag uintptr
//go:linkname flagIndir reflect.flagIndir
const flagIndir flag = 1 << 7
func getUnsafePointer(rv *reflect.Value) unsafe.Pointer {
vv := (*refVal)(unsafe.Pointer(rv))
if vv.flag&flagIndir == 0 {
return unsafe.Pointer(&vv.ptr)
} else {
return vv.ptr
}
}

185
internal/gotiny/utils.go Normal file
View File

@ -0,0 +1,185 @@
package gotiny
import (
"encoding"
"encoding/gob"
"reflect"
"strings"
"unsafe"
)
const (
ptr1Size = 4 << (^uintptr(0) >> 63) // unsafe.Sizeof(uintptr(0)) but an ideal const
)
func float64ToUint64(v unsafe.Pointer) uint64 {
return reverse64Byte(*(*uint64)(v))
}
func uint64ToFloat64(u uint64) float64 {
u = reverse64Byte(u)
return *((*float64)(unsafe.Pointer(&u)))
}
func reverse64Byte(u uint64) uint64 {
u = (u << 32) | (u >> 32)
u = ((u << 16) & 0xFFFF0000FFFF0000) | ((u >> 16) & 0xFFFF0000FFFF)
u = ((u << 8) & 0xFF00FF00FF00FF00) | ((u >> 8) & 0xFF00FF00FF00FF)
return u
}
func float32ToUint32(v unsafe.Pointer) uint32 {
return reverse32Byte(*(*uint32)(v))
}
func uint32ToFloat32(u uint32) float32 {
u = reverse32Byte(u)
return *((*float32)(unsafe.Pointer(&u)))
}
func reverse32Byte(u uint32) uint32 {
u = (u << 16) | (u >> 16)
return ((u << 8) & 0xFF00FF00) | ((u >> 8) & 0xFF00FF)
}
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
// uint 9 7 5 3 1 0 2 4 6 8 10 12
func int64ToUint64(v int64) uint64 {
return uint64((v << 1) ^ (v >> 63))
}
// uint 9 7 5 3 1 0 2 4 6 8 10 12
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
func uint64ToInt64(u uint64) int64 {
v := int64(u)
return (-(v & 1)) ^ (v>>1)&0x7FFFFFFFFFFFFFFF
}
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
// uint 9 7 5 3 1 0 2 4 6 8 10 12
func int32ToUint32(v int32) uint32 {
return uint32((v << 1) ^ (v >> 31))
}
// uint 9 7 5 3 1 0 2 4 6 8 10 12
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
func uint32ToInt32(u uint32) int32 {
v := int32(u)
return (-(v & 1)) ^ (v>>1)&0x7FFFFFFF
}
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
// uint 9 7 5 3 1 0 2 4 6 8 10 12
func int16ToUint16(v int16) uint16 {
return uint16((v << 1) ^ (v >> 15))
}
// uint 9 7 5 3 1 0 2 4 6 8 10 12
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
func uint16ToInt16(u uint16) int16 {
v := int16(u)
return (-(v & 1)) ^ (v>>1)&0x7FFF
}
func isNil(p unsafe.Pointer) bool {
return *(*unsafe.Pointer)(p) == nil
}
type gobInter interface {
gob.GobEncoder
gob.GobDecoder
}
type binInter interface {
encoding.BinaryMarshaler
encoding.BinaryUnmarshaler
}
// 只应该由指针来实现该接口
type GoTinySerializer interface {
// 编码方法将对象的序列化结果append到入参数并返回方法不应该修改入参数值原有的值
GotinyEncode([]byte) []byte
// 解码方法将入参解码到对象里并返回使用的长度。方法从入参的第0个字节开始使用并且不应该修改入参中的任何数据
GotinyDecode([]byte) int
}
func implementOtherSerializer(rt reflect.Type) (encEng encEng, decEng decEng) {
rtNil := reflect.Zero(reflect.PtrTo(rt)).Interface()
if _, ok := rtNil.(GoTinySerializer); ok {
encEng = func(e *Encoder, p unsafe.Pointer) {
e.buf = reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyEncode(e.buf)
}
decEng = func(d *Decoder, p unsafe.Pointer) {
d.index += reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyDecode(d.buf[d.index:])
}
return
}
if _, ok := rtNil.(binInter); ok {
encEng = func(e *Encoder, p unsafe.Pointer) {
buf, err := reflect.NewAt(rt, p).Interface().(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
panic(err)
}
e.encLength(len(buf))
e.buf = append(e.buf, buf...)
}
decEng = func(d *Decoder, p unsafe.Pointer) {
length := d.decLength()
start := d.index
d.index += length
if err := reflect.NewAt(rt, p).Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(d.buf[start:d.index]); err != nil {
panic(err)
}
}
return
}
if _, ok := rtNil.(gobInter); ok {
encEng = func(e *Encoder, p unsafe.Pointer) {
buf, err := reflect.NewAt(rt, p).Interface().(gob.GobEncoder).GobEncode()
if err != nil {
panic(err)
}
e.encLength(len(buf))
e.buf = append(e.buf, buf...)
}
decEng = func(d *Decoder, p unsafe.Pointer) {
length := d.decLength()
start := d.index
d.index += length
if err := reflect.NewAt(rt, p).Interface().(gob.GobDecoder).GobDecode(d.buf[start:d.index]); err != nil {
panic(err)
}
}
}
return
}
// rt.kind is reflect.struct
func getFieldType(rt reflect.Type, baseOff uintptr) (fields []reflect.Type, offs []uintptr) {
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
if ignoreField(field) {
continue
}
ft := field.Type
if ft.Kind() == reflect.Struct {
if _, engine := implementOtherSerializer(ft); engine == nil {
fFields, fOffs := getFieldType(ft, field.Offset+baseOff)
fields = append(fields, fFields...)
offs = append(offs, fOffs...)
continue
}
}
fields = append(fields, ft)
offs = append(offs, field.Offset+baseOff)
}
return
}
func ignoreField(field reflect.StructField) bool {
tinyTag, ok := field.Tag.Lookup("gotiny")
return ok && strings.TrimSpace(tinyTag) == "-"
}

90
internal/memory/memory.go Normal file
View File

@ -0,0 +1,90 @@
package memory
import (
"sync"
"sync/atomic"
"time"
)
type Storage struct {
sync.RWMutex
data map[string]item // data
ts uint64 // timestamp
}
type item struct {
v interface{} // val
e uint64 // exp
}
func New() *Storage {
store := &Storage{
data: make(map[string]item),
ts: uint64(time.Now().Unix()),
}
go store.gc(10 * time.Millisecond)
go store.updater(1 * time.Second)
return store
}
// Get value by key
func (s *Storage) Get(key string) interface{} {
s.RLock()
v, ok := s.data[key]
s.RUnlock()
if !ok || v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
return nil
}
return v.v
}
// Set key with value
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
var exp uint64
if ttl > 0 {
exp = uint64(ttl.Seconds()) + atomic.LoadUint64(&s.ts)
}
s.Lock()
s.data[key] = item{val, exp}
s.Unlock()
}
// Delete key by key
func (s *Storage) Delete(key string) {
s.Lock()
delete(s.data, key)
s.Unlock()
}
// Reset all keys
func (s *Storage) Reset() {
s.Lock()
s.data = make(map[string]item)
s.Unlock()
}
func (s *Storage) updater(sleep time.Duration) {
for {
time.Sleep(sleep)
atomic.StoreUint64(&s.ts, uint64(time.Now().Unix()))
}
}
func (s *Storage) gc(sleep time.Duration) {
expired := []string{}
for {
time.Sleep(sleep)
expired = expired[:0]
s.RLock()
for key, v := range s.data {
if v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
expired = append(expired, key)
}
}
s.RUnlock()
s.Lock()
for i := range expired {
delete(s.data, expired[i])
}
s.Unlock()
}
}

View File

@ -0,0 +1,81 @@
package memory
import (
"testing"
"time"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_Memory -v -race
func Test_Memory(t *testing.T) {
var store = New()
var (
key = "john"
val interface{} = []byte("doe")
exp = 1 * time.Second
)
store.Set(key, val, 0)
store.Set(key, val, 0)
result := store.Get(key)
utils.AssertEqual(t, val, result)
result = store.Get("empty")
utils.AssertEqual(t, nil, result)
store.Set(key, val, exp)
time.Sleep(1100 * time.Millisecond)
result = store.Get(key)
utils.AssertEqual(t, nil, result)
store.Set(key, val, 0)
result = store.Get(key)
utils.AssertEqual(t, val, result)
store.Delete(key)
result = store.Get(key)
utils.AssertEqual(t, nil, result)
store.Set("john", val, 0)
store.Set("doe", val, 0)
store.Reset()
result = store.Get("john")
utils.AssertEqual(t, nil, result)
result = store.Get("doe")
utils.AssertEqual(t, nil, result)
}
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
func Benchmark_Memory(b *testing.B) {
keyLength := 1000
keys := make([]string, keyLength)
for i := 0; i < keyLength; i++ {
keys[i] = utils.UUID()
}
value := []string{"some", "random", "value"}
ttl := 2 * time.Second
b.Run("fiber_memory", func(b *testing.B) {
d := New()
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
for _, key := range keys {
d.Set(key, value, ttl)
}
for _, key := range keys {
_ = d.Get(key)
}
for _, key := range keys {
d.Delete(key)
}
}
})
}

View File

@ -1,33 +0,0 @@
package memory
import "time"
// Config defines the config for storage.
type Config struct {
// Time before deleting expired keys
//
// Default is 10 * time.Second
GCInterval time.Duration
}
// ConfigDefault is the default config
var ConfigDefault = Config{
GCInterval: 10 * time.Second,
}
// configDefault is a helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if int(cfg.GCInterval.Seconds()) <= 0 {
cfg.GCInterval = ConfigDefault.GCInterval
}
return cfg
}

View File

@ -1,7 +1,6 @@
package memory
import (
"errors"
"sync"
"time"
)
@ -14,23 +13,17 @@ type Storage struct {
done chan struct{}
}
// Common storage errors
var ErrNotExist = errors.New("key does not exist")
type entry struct {
data []byte
expiry int64
}
// New creates a new memory storage
func New(config ...Config) *Storage {
// Set default config
cfg := configDefault(config...)
func New() *Storage {
// Create storage
store := &Storage{
db: make(map[string]entry),
gcInterval: cfg.GCInterval,
gcInterval: 10 * time.Second,
done: make(chan struct{}),
}
@ -43,13 +36,13 @@ func New(config ...Config) *Storage {
// Get value by key
func (s *Storage) Get(key string) ([]byte, error) {
if len(key) <= 0 {
return nil, ErrNotExist
return nil, nil
}
s.mux.RLock()
v, ok := s.db[key]
s.mux.RUnlock()
if !ok || v.expiry != 0 && v.expiry <= time.Now().Unix() {
return nil, ErrNotExist
return nil, nil
}
return v.data, nil

View File

@ -0,0 +1,9 @@
Paul Borman <borman@google.com>
bmatsuo
shawnps
theory
jboverfelt
dsymonds
cd1
wallclockbuilder
dansouza

27
internal/uuid/LICENSE Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009,2014 Google Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

80
internal/uuid/dce.go Normal file
View File

@ -0,0 +1,80 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"fmt"
"os"
)
// A Domain represents a Version 2 domain
type Domain byte
// Domain constants for DCE Security (Version 2) UUIDs.
const (
Person = Domain(0)
Group = Domain(1)
Org = Domain(2)
)
// NewDCESecurity returns a DCE Security (Version 2) UUID.
//
// The domain should be one of Person, Group or Org.
// On a POSIX system the id should be the users UID for the Person
// domain and the users GID for the Group. The meaning of id for
// the domain Org or on non-POSIX systems is site defined.
//
// For a given domain/id pair the same token may be returned for up to
// 7 minutes and 10 seconds.
func NewDCESecurity(domain Domain, id uint32) (UUID, error) {
uuid, err := NewUUID()
if err == nil {
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
uuid[9] = byte(domain)
binary.BigEndian.PutUint32(uuid[0:], id)
}
return uuid, err
}
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
// domain with the id returned by os.Getuid.
//
// NewDCESecurity(Person, uint32(os.Getuid()))
func NewDCEPerson() (UUID, error) {
return NewDCESecurity(Person, uint32(os.Getuid()))
}
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
// domain with the id returned by os.Getgid.
//
// NewDCESecurity(Group, uint32(os.Getgid()))
func NewDCEGroup() (UUID, error) {
return NewDCESecurity(Group, uint32(os.Getgid()))
}
// Domain returns the domain for a Version 2 UUID. Domains are only defined
// for Version 2 UUIDs.
func (uuid UUID) Domain() Domain {
return Domain(uuid[9])
}
// ID returns the id for a Version 2 UUID. IDs are only defined for Version 2
// UUIDs.
func (uuid UUID) ID() uint32 {
return binary.BigEndian.Uint32(uuid[0:4])
}
func (d Domain) String() string {
switch d {
case Person:
return "Person"
case Group:
return "Group"
case Org:
return "Org"
}
return fmt.Sprintf("Domain%d", int(d))
}

12
internal/uuid/doc.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package uuid generates and inspects UUIDs.
//
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security
// Services.
//
// A UUID is a 16 byte (128 bit) array. UUIDs may be used as keys to
// maps or compared directly.
package uuid

53
internal/uuid/hash.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"crypto/md5"
"crypto/sha1"
"hash"
)
// Well known namespace IDs and UUIDs
var (
NameSpaceDNS = Must(Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceURL = Must(Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceOID = Must(Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceX500 = Must(Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
Nil UUID // empty UUID, all zeros
)
// NewHash returns a new UUID derived from the hash of space concatenated with
// data generated by h. The hash should be at least 16 byte in length. The
// first 16 bytes of the hash are used to form the UUID. The version of the
// UUID will be the lower 4 bits of version. NewHash is used to implement
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space[:])
h.Write(data)
s := h.Sum(nil)
var uuid UUID
copy(uuid[:], s)
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
return uuid
}
// NewMD5 returns a new MD5 (Version 3) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(md5.New(), space, data, 3)
func NewMD5(space UUID, data []byte) UUID {
return NewHash(md5.New(), space, data, 3)
}
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(sha1.New(), space, data, 5)
func NewSHA1(space UUID, data []byte) UUID {
return NewHash(sha1.New(), space, data, 5)
}

38
internal/uuid/marshal.go Normal file
View File

@ -0,0 +1,38 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "fmt"
// MarshalText implements encoding.TextMarshaler.
func (uuid UUID) MarshalText() ([]byte, error) {
var js [36]byte
encodeHex(js[:], uuid)
return js[:], nil
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (uuid *UUID) UnmarshalText(data []byte) error {
id, err := ParseBytes(data)
if err != nil {
return err
}
*uuid = id
return nil
}
// MarshalBinary implements encoding.BinaryMarshaler.
func (uuid UUID) MarshalBinary() ([]byte, error) {
return uuid[:], nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (uuid *UUID) UnmarshalBinary(data []byte) error {
if len(data) != 16 {
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
}
copy(uuid[:], data)
return nil
}

90
internal/uuid/node.go Normal file
View File

@ -0,0 +1,90 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"sync"
)
var (
nodeMu sync.Mutex
ifname string // name of interface being used
nodeID [6]byte // hardware for version 1 UUIDs
zeroID [6]byte // nodeID with only 0's
)
// NodeInterface returns the name of the interface from which the NodeID was
// derived. The interface "user" is returned if the NodeID was set by
// SetNodeID.
func NodeInterface() string {
defer nodeMu.Unlock()
nodeMu.Lock()
return ifname
}
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
// If name is "" then the first usable interface found will be used or a random
// Node ID will be generated. If a named interface cannot be found then false
// is returned.
//
// SetNodeInterface never fails when name is "".
func SetNodeInterface(name string) bool {
defer nodeMu.Unlock()
nodeMu.Lock()
return setNodeInterface(name)
}
func setNodeInterface(name string) bool {
iname, addr := getHardwareInterface(name) // null implementation for js
if iname != "" && addr != nil {
ifname = iname
copy(nodeID[:], addr)
return true
}
// We found no interfaces with a valid hardware address. If name
// does not specify a specific interface generate a random Node ID
// (section 4.1.6)
if name == "" {
ifname = "random"
randomBits(nodeID[:])
return true
}
return false
}
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
// if not already set.
func NodeID() []byte {
defer nodeMu.Unlock()
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
nid := nodeID
return nid[:]
}
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
// of id are used. If id is less than 6 bytes then false is returned and the
// Node ID is not set.
func SetNodeID(id []byte) bool {
if len(id) < 6 {
return false
}
defer nodeMu.Unlock()
nodeMu.Lock()
copy(nodeID[:], id)
ifname = "user"
return true
}
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) NodeID() []byte {
var node [6]byte
copy(node[:], uuid[10:])
return node[:]
}

12
internal/uuid/node_js.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build js
package uuid
// getHardwareInterface returns nil values for the JS version of the code.
// This remvoves the "net" dependency, because it is not used in the browser.
// Using the "net" library inflates the size of the transpiled JS code by 673k bytes.
func getHardwareInterface(name string) (string, []byte) { return "", nil }

33
internal/uuid/node_net.go Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !js
package uuid
import "net"
var interfaces []net.Interface // cached list of interfaces
// getHardwareInterface returns the name and hardware address of interface name.
// If name is "" then the name and hardware address of one of the system's
// interfaces is returned. If no interfaces are found (name does not exist or
// there are no interfaces) then "", nil is returned.
//
// Only addresses of at least 6 bytes are returned.
func getHardwareInterface(name string) (string, []byte) {
if interfaces == nil {
var err error
interfaces, err = net.Interfaces()
if err != nil {
return "", nil
}
}
for _, ifs := range interfaces {
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
return ifs.Name, ifs.HardwareAddr
}
}
return "", nil
}

59
internal/uuid/sql.go Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"database/sql/driver"
"fmt"
)
// Scan implements sql.Scanner so UUIDs can be read from databases transparently
// Currently, database types that map to string and []byte are supported. Please
// consult database-specific driver documentation for matching types.
func (uuid *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case nil:
return nil
case string:
// if an empty UUID comes from a table, we return a null UUID
if src == "" {
return nil
}
// see Parse for required string format
u, err := Parse(src)
if err != nil {
return fmt.Errorf("Scan: %v", err)
}
*uuid = u
case []byte:
// if an empty UUID comes from a table, we return a null UUID
if len(src) == 0 {
return nil
}
// assumes a simple slice of bytes if 16 bytes
// otherwise attempts to parse
if len(src) != 16 {
return uuid.Scan(string(src))
}
copy((*uuid)[:], src)
default:
return fmt.Errorf("Scan: unable to scan type %T into UUID", src)
}
return nil
}
// Value implements sql.Valuer so that UUIDs can be written to databases
// transparently. Currently, UUIDs map to strings. Please consult
// database-specific driver documentation for matching types.
func (uuid UUID) Value() (driver.Value, error) {
return uuid.String(), nil
}

123
internal/uuid/time.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"sync"
"time"
)
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
// 1582.
type Time int64
const (
lillian = 2299160 // Julian day of 15 Oct 1582
unix = 2440587 // Julian day of 1 Jan 1970
epoch = unix - lillian // Days between epochs
g1582 = epoch * 86400 // seconds between epochs
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
)
var (
timeMu sync.Mutex
lasttime uint64 // last time we returned
clockSeq uint16 // clock sequence for this run
timeNow = time.Now // for testing
)
// UnixTime converts t the number of seconds and nanoseconds using the Unix
// epoch of 1 Jan 1970.
func (t Time) UnixTime() (sec, nsec int64) {
sec = int64(t - g1582ns100)
nsec = (sec % 10000000) * 100
sec /= 10000000
return sec, nsec
}
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
// clock sequence as well as adjusting the clock sequence as needed. An error
// is returned if the current time cannot be determined.
func GetTime() (Time, uint16, error) {
defer timeMu.Unlock()
timeMu.Lock()
return getTime()
}
func getTime() (Time, uint16, error) {
t := timeNow()
// If we don't have a clock sequence already, set one.
if clockSeq == 0 {
setClockSequence(-1)
}
now := uint64(t.UnixNano()/100) + g1582ns100
// If time has gone backwards with this clock sequence then we
// increment the clock sequence
if now <= lasttime {
clockSeq = ((clockSeq + 1) & 0x3fff) | 0x8000
}
lasttime = now
return Time(now), clockSeq, nil
}
// ClockSequence returns the current clock sequence, generating one if not
// already set. The clock sequence is only used for Version 1 UUIDs.
//
// The uuid package does not use global static storage for the clock sequence or
// the last time a UUID was generated. Unless SetClockSequence is used, a new
// random clock sequence is generated the first time a clock sequence is
// requested by ClockSequence, GetTime, or NewUUID. (section 4.2.1.1)
func ClockSequence() int {
defer timeMu.Unlock()
timeMu.Lock()
return clockSequence()
}
func clockSequence() int {
if clockSeq == 0 {
setClockSequence(-1)
}
return int(clockSeq & 0x3fff)
}
// SetClockSequence sets the clock sequence to the lower 14 bits of seq. Setting to
// -1 causes a new sequence to be generated.
func SetClockSequence(seq int) {
defer timeMu.Unlock()
timeMu.Lock()
setClockSequence(seq)
}
func setClockSequence(seq int) {
if seq == -1 {
var b [2]byte
randomBits(b[:]) // clock sequence
seq = int(b[0])<<8 | int(b[1])
}
oldSeq := clockSeq
clockSeq = uint16(seq&0x3fff) | 0x8000 // Set our variant
if oldSeq != clockSeq {
lasttime = 0
}
}
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
// uuid. The time is only defined for version 1 and 2 UUIDs.
func (uuid UUID) Time() Time {
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
return Time(time)
}
// ClockSequence returns the clock sequence encoded in uuid.
// The clock sequence is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) ClockSequence() int {
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff
}

43
internal/uuid/util.go Normal file
View File

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"io"
)
// randomBits completely fills slice b with random data.
func randomBits(b []byte) {
if _, err := io.ReadFull(rander, b); err != nil {
panic(err.Error()) // rand should never fail
}
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
// xtob converts hex characters x1 and x2 into a byte.
func xtob(x1, x2 byte) (byte, bool) {
b1 := xvalues[x1]
b2 := xvalues[x2]
return (b1 << 4) | b2, b1 != 255 && b2 != 255
}

245
internal/uuid/uuid.go Normal file
View File

@ -0,0 +1,245 @@
// Copyright 2018 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
// 4122.
type UUID [16]byte
// A Version represents a UUID's version.
type Version byte
// A Variant represents a UUID's variant.
type Variant byte
// Constants returned by Variant.
const (
Invalid = Variant(iota) // Invalid UUID
RFC4122 // The variant specified in RFC4122
Reserved // Reserved, NCS backward compatibility.
Microsoft // Reserved, Microsoft Corporation backward compatibility.
Future // Reserved for future definition.
)
var rander = rand.Reader // random function
// Parse decodes s into a UUID or returns an error. Both the standard UUID
// forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex
// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.
func Parse(s string) (UUID, error) {
var uuid UUID
switch len(s) {
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36:
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9:
if strings.ToLower(s[:9]) != "urn:uuid:" {
return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9])
}
s = s[9:]
// {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
case 36 + 2:
s = s[1:]
// xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
case 32:
var ok bool
for i := range uuid {
uuid[i], ok = xtob(s[i*2], s[i*2+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
v, ok := xtob(s[x], s[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// ParseBytes is like Parse, except it parses a byte slice instead of a string.
func ParseBytes(b []byte) (UUID, error) {
var uuid UUID
switch len(b) {
case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) {
return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
}
b = b[9:]
case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
b = b[1:]
case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
var ok bool
for i := 0; i < 32; i += 2 {
uuid[i/2], ok = xtob(b[i], b[i+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
v, ok := xtob(b[x], b[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// MustParse is like Parse but panics if the string cannot be parsed.
// It simplifies safe initialization of global variables holding compiled UUIDs.
func MustParse(s string) UUID {
uuid, err := Parse(s)
if err != nil {
panic(`uuid: Parse(` + s + `): ` + err.Error())
}
return uuid
}
// FromBytes creates a new UUID from a byte slice. Returns an error if the slice
// does not have a length of 16. The bytes are copied from the slice.
func FromBytes(b []byte) (uuid UUID, err error) {
err = uuid.UnmarshalBinary(b)
return uuid, err
}
// Must returns uuid if err is nil and panics otherwise.
func Must(uuid UUID, err error) UUID {
if err != nil {
panic(err)
}
return uuid
}
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
// , or "" if uuid is invalid.
func (uuid UUID) String() string {
var buf [36]byte
encodeHex(buf[:], uuid)
return string(buf[:])
}
// URN returns the RFC 2141 URN form of uuid,
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
func (uuid UUID) URN() string {
var buf [36 + 9]byte
copy(buf[:], "urn:uuid:")
encodeHex(buf[9:], uuid)
return string(buf[:])
}
func encodeHex(dst []byte, uuid UUID) {
hex.Encode(dst, uuid[:4])
dst[8] = '-'
hex.Encode(dst[9:13], uuid[4:6])
dst[13] = '-'
hex.Encode(dst[14:18], uuid[6:8])
dst[18] = '-'
hex.Encode(dst[19:23], uuid[8:10])
dst[23] = '-'
hex.Encode(dst[24:], uuid[10:])
}
// Variant returns the variant encoded in uuid.
func (uuid UUID) Variant() Variant {
switch {
case (uuid[8] & 0xc0) == 0x80:
return RFC4122
case (uuid[8] & 0xe0) == 0xc0:
return Microsoft
case (uuid[8] & 0xe0) == 0xe0:
return Future
default:
return Reserved
}
}
// Version returns the version of uuid.
func (uuid UUID) Version() Version {
return Version(uuid[6] >> 4)
}
func (v Version) String() string {
if v > 15 {
return fmt.Sprintf("BAD_VERSION_%d", v)
}
return fmt.Sprintf("VERSION_%d", v)
}
func (v Variant) String() string {
switch v {
case RFC4122:
return "RFC4122"
case Reserved:
return "Reserved"
case Microsoft:
return "Microsoft"
case Future:
return "Future"
case Invalid:
return "Invalid"
}
return fmt.Sprintf("BadVariant%d", int(v))
}
// SetRand sets the random number generator to r, which implements io.Reader.
// If r.Read returns an error when the package requests random data then
// a panic will be issued.
//
// Calling SetRand with nil sets the random number generator to the default
// generator.
func SetRand(r io.Reader) {
if r == nil {
rander = rand.Reader
return
}
rander = r
}

44
internal/uuid/version1.go Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
)
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
// sequence, and the current time. If the NodeID has not been set by SetNodeID
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
// be set NewUUID returns nil. If clock sequence has not been set by
// SetClockSequence then it will be set automatically. If GetTime fails to
// return the current NewUUID returns nil and an error.
//
// In most cases, New should be used.
func NewUUID() (UUID, error) {
var uuid UUID
now, seq, err := GetTime()
if err != nil {
return uuid, err
}
timeLow := uint32(now & 0xffffffff)
timeMid := uint16((now >> 32) & 0xffff)
timeHi := uint16((now >> 48) & 0x0fff)
timeHi |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uuid[0:], timeLow)
binary.BigEndian.PutUint16(uuid[4:], timeMid)
binary.BigEndian.PutUint16(uuid[6:], timeHi)
binary.BigEndian.PutUint16(uuid[8:], seq)
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
copy(uuid[10:], nodeID[:])
nodeMu.Unlock()
return uuid, nil
}

43
internal/uuid/version4.go Normal file
View File

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "io"
// New creates a new random UUID or panics. New is equivalent to
// the expression
//
// uuid.Must(uuid.NewRandom())
func New() UUID {
return Must(NewRandom())
}
// NewRandom returns a Random (Version 4) UUID.
//
// The strength of the UUIDs is based on the strength of the crypto/rand
// package.
//
// A note about uniqueness derived from the UUID Wikipedia entry:
//
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
// hit by a meteorite is estimated to be one chance in 17 billion, that
// means the probability is about 0.00000000006 (6 × 1011),
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
// year and having one duplicate.
func NewRandom() (UUID, error) {
return NewRandomFromReader(rander)
}
// NewRandomFromReader returns a UUID based on bytes read from a given io.Reader.
func NewRandomFromReader(r io.Reader) (UUID, error) {
var uuid UUID
_, err := io.ReadFull(r, uuid[:])
if err != nil {
return Nil, err
}
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid, nil
}

View File

@ -75,12 +75,12 @@ type Config struct {
// Default: func(c *fiber.Ctx) string {
// return c.Path()
// }
Key func(*fiber.Ctx) string
KeyGenerator func(*fiber.Ctx) string
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store fiber.Storage
Storage fiber.Storage
}
```
@ -92,8 +92,9 @@ var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
Key: func(c *fiber.Ctx) string {
KeyGenerator: func(c *fiber.Ctx) string {
return c.Path()
},
Storage: nil,
}
```

View File

@ -17,15 +17,21 @@ func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
return c.Next()
}
}
var (
// Cache settings
mux = &sync.RWMutex{}
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
mux = &sync.RWMutex{}
// Default store logic (if no Store is provided)
entries = make(map[string]entry)
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
go func() {
@ -35,30 +41,6 @@ func New(config ...Config) fiber.Handler {
}
}()
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
return c.Next()
}
}
// Remove expired entries
if cfg.defaultStore {
go func() {
for {
// GC the entries every 10 seconds
time.Sleep(10 * time.Second)
mux.Lock()
for k := range entries {
if atomic.LoadUint64(&timestamp) >= entries[k].exp {
delete(entries, k)
}
}
mux.Unlock()
}
}()
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
@ -72,74 +54,45 @@ func New(config ...Config) fiber.Handler {
}
// Get key from request
key := cfg.Key(c)
key := cfg.KeyGenerator(c)
// Create new entry
var entry entry
var entryBody []byte
// Get entry from pool
e := manager.get(key)
// Lock entry
// Lock entry and unlock when finished
mux.Lock()
defer mux.Unlock()
// Check if we need to use the default in-memory storage
if cfg.defaultStore {
entry = entries[key]
} else {
// Load data from store
storeEntry, err := cfg.Storage.Get(key)
if err != nil {
return err
}
// Only decode if we found an entry
if storeEntry != nil {
// Decode bytes using msgp
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
return err
}
}
if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil {
return err
}
}
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Set expiration if entry does not exist
if entry.exp == 0 {
entry.exp = ts + expiration
if e.exp == 0 {
// Set expiration if entry does not exist
e.exp = ts + expiration
} else if ts >= entry.exp {
} else if ts >= e.exp {
// Check if entry is expired
// Use default memory storage
if cfg.defaultStore {
delete(entries, key)
} else { // Use custom storage
if err := cfg.Storage.Delete(key); err != nil {
return err
}
if err := cfg.Storage.Delete(key + "_body"); err != nil {
return err
}
manager.delete(key)
// External storage saves body data with different key
if cfg.Storage != nil {
manager.delete(key + "_body")
}
} else {
if cfg.defaultStore {
c.Response().SetBodyRaw(entry.body)
} else {
c.Response().SetBodyRaw(entryBody)
// Seperate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}
// Set response headers from cache
c.Response().SetStatusCode(entry.status)
c.Response().Header.SetContentTypeBytes(entry.cType)
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
if len(e.cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
}
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(entry.exp-ts, 10)
maxAge := strconv.FormatUint(e.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
@ -153,31 +106,21 @@ func New(config ...Config) fiber.Handler {
}
// Cache response
entryBody = utils.SafeBytes(c.Response().Body())
entry.status = c.Response().StatusCode()
entry.cType = utils.SafeBytes(c.Response().Header.ContentType())
// Use default memory storage
if cfg.defaultStore {
entry.body = entryBody
entries[key] = entry
e.body = utils.SafeBytes(c.Response().Body())
e.status = c.Response().StatusCode()
e.ctype = utils.SafeBytes(c.Response().Header.ContentType())
e.cencoding = utils.SafeBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
// For external Storage we store raw body seperated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, cfg.Expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, cfg.Expiration)
manager.release(e)
} else {
// Use custom storage
data, err := entry.MarshalMsg(nil)
if err != nil {
return err
}
// Pass bytes to Storage
if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil {
return err
}
// Pass bytes to Storage
if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
return err
}
// Store entry in memory
manager.set(key, e, cfg.Expiration)
}
// Finish response

View File

@ -6,13 +6,12 @@ import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -93,54 +92,54 @@ func Test_Cache(t *testing.T) {
utils.AssertEqual(t, cachedBody, body)
}
// go test -run Test_Cache_Concurrency_Store -race -v
func Test_Cache_Concurrency_Store(t *testing.T) {
// Test concurrency using a custom store
// // go test -run Test_Cache_Concurrency_Storage -race -v
// func Test_Cache_Concurrency_Storage(t *testing.T) {
// // Test concurrency using a custom store
app := fiber.New()
// app := fiber.New()
app.Use(New(Config{
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
}))
// app.Use(New(Config{
// Storage: memory.New(),
// }))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
// app.Get("/", func(c *fiber.Ctx) error {
// return c.SendString("Hello tester!")
// })
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// var wg sync.WaitGroup
// singleRequest := func(wg *sync.WaitGroup) {
// defer wg.Done()
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Hello tester!", string(body))
// }
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
// for i := 0; i <= 49; i++ {
// wg.Add(1)
// go singleRequest(&wg)
// }
wg.Wait()
// wg.Wait()
req := httptest.NewRequest("GET", "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
// req := httptest.NewRequest("GET", "/", nil)
// resp, err := app.Test(req)
// utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest("GET", "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
// cachedReq := httptest.NewRequest("GET", "/", nil)
// cachedResp, err := app.Test(cachedReq)
// utils.AssertEqual(t, nil, err)
body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// cachedBody, err := ioutil.ReadAll(cachedResp.Body)
// utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cachedBody, body)
}
// utils.AssertEqual(t, cachedBody, body)
// }
func Test_Cache_Invalid_Expiration(t *testing.T) {
app := fiber.New()
@ -235,7 +234,7 @@ func Test_Cache_NothingToCache(t *testing.T) {
func Test_CustomKey(t *testing.T) {
app := fiber.New()
var called bool
app.Use(New(Config{Key: func(c *fiber.Ctx) string {
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
called = true
return c.Path()
}}))
@ -279,12 +278,12 @@ func Benchmark_Cache(b *testing.B) {
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
}
// go test -v -run=^$ -bench=Benchmark_Cache_Store -benchmem -count=4
func Benchmark_Cache_Store(b *testing.B) {
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
func Benchmark_Cache_Storage(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
Storage: memory.New(),
}))
app.Get("/demo", func(c *fiber.Ctx) error {
@ -308,43 +307,3 @@ func Benchmark_Cache_Store(b *testing.B) {
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
}
// testStore is used for testing custom stores
type testStore struct {
stmap map[string][]byte
mutex *sync.RWMutex
}
func (s testStore) Get(id string) ([]byte, error) {
s.mutex.RLock()
val, ok := s.stmap[id]
s.mutex.RUnlock()
if !ok {
return nil, nil
} else {
return val, nil
}
}
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()
return nil
}
func (s testStore) Reset() error {
s.stmap = map[string][]byte{}
return nil
}
func (s testStore) Delete(id string) error {
s.mutex.Lock()
delete(s.stmap, id)
s.mutex.Unlock()
return nil
}
func (s testStore) Close() error {
return nil
}

View File

@ -29,19 +29,18 @@ type Config struct {
// Default: func(c *fiber.Ctx) string {
// return c.Path()
// }
Key func(*fiber.Ctx) string
// Deprecated, use Storage instead
Store fiber.Storage
KeyGenerator func(*fiber.Ctx) string
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
// Deprecated, use Storage instead
Store fiber.Storage
// Deprecated, use KeyGenerator instead
Key func(*fiber.Ctx) string
}
// ConfigDefault is the default config
@ -49,10 +48,10 @@ var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
Key: func(c *fiber.Ctx) string {
KeyGenerator: func(c *fiber.Ctx) string {
return c.Path()
},
defaultStore: true,
Storage: nil,
}
// Helper function to set default values
@ -66,22 +65,22 @@ func configDefault(config ...Config) Config {
cfg := config[0]
// Set default values
if cfg.Store != nil {
fmt.Println("[CACHE] Store is deprecated, please use Storage")
cfg.Storage = cfg.Store
}
if cfg.Key != nil {
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
cfg.KeyGenerator = cfg.Key
}
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.Key == nil {
cfg.Key = ConfigDefault.Key
}
if cfg.Storage == nil && cfg.Store == nil {
cfg.defaultStore = true
}
if cfg.Store != nil {
fmt.Println("cache: `Store` is deprecated, use `Storage` instead")
cfg.Storage = cfg.Store
cfg.defaultStore = true
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
return cfg
}

122
middleware/cache/manager.go vendored Normal file
View File

@ -0,0 +1,122 @@
package cache
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
body []byte
ctype []byte
cencoding []byte
status int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
e.body = nil
e.ctype = nil
e.status = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -7,7 +7,7 @@ import (
)
// DecodeMsg implements msgp.Decodable
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
@ -30,10 +30,16 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, err = dc.ReadBytes(z.cType)
case "ctype":
z.ctype, err = dc.ReadBytes(z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
case "cencoding":
z.cencoding, err = dc.ReadBytes(z.cencoding)
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
case "status":
@ -60,10 +66,10 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 4
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 5
// write "body"
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
err = en.Append(0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
if err != nil {
return
}
@ -72,14 +78,24 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
err = msgp.WrapError(err, "body")
return
}
// write "cType"
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
// write "ctype"
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.cType)
err = en.WriteBytes(z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
// write "cencoding"
err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
if err != nil {
return
}
err = en.WriteBytes(z.cencoding)
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
// write "status"
@ -106,15 +122,18 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
}
// MarshalMsg implements msgp.Marshaler
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 4
// map header, size 5
// string "body"
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = append(o, 0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = msgp.AppendBytes(o, z.body)
// string "cType"
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.cType)
// string "ctype"
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.ctype)
// string "cencoding"
o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
o = msgp.AppendBytes(o, z.cencoding)
// string "status"
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
o = msgp.AppendInt(o, z.status)
@ -125,7 +144,7 @@ func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
@ -148,10 +167,16 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
err = msgp.WrapError(err, "body")
return
}
case "cType":
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
case "ctype":
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
if err != nil {
err = msgp.WrapError(err, "cType")
err = msgp.WrapError(err, "ctype")
return
}
case "cencoding":
z.cencoding, bts, err = msgp.ReadBytesBytes(bts, z.cencoding)
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
case "status":
@ -179,7 +204,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *entry) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
func (z *item) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

View File

@ -1,12 +0,0 @@
package cache
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
body []byte `msg:"body"`
cType []byte `msg:"cType"`
status int `msg:"status"`
exp uint64 `msg:"exp"`
}

View File

@ -23,9 +23,9 @@ _NOTE: This middleware uses our [Storage](https://github.com/gofiber/storage) pa
func New(config ...Config) fiber.Handler
```
## Examples
### Examples
First import the middleware from Fiber,
Import the middleware package that is part of the Fiber web framework
```go
import (
@ -34,11 +34,7 @@ import (
)
```
Then create a Fiber app with `app := fiber.New()`.
### Default Config
Then apply the middleware to your Fiber app,
After you initiate your Fiber app, you can use the following possibilities:
```go
app.Use(csrf.New()) // Default config
@ -90,7 +86,7 @@ type Config struct {
KeyLookup string
// Name of the session cookie. This cookie will store session key.
// Optional. Default value "_csrf".
// Optional. Default value "csrf_".
CookieName string
// Domain of the CSRF cookie.
@ -109,7 +105,7 @@ type Config struct {
// Optional. Default value false.
CookieHTTPOnly bool
// Indicates if CSRF cookie is HTTP only.
// Indicates if CSRF cookie is requested by SameSite.
// Optional. Default value "Strict".
CookieSameSite string

View File

@ -28,7 +28,7 @@ type Config struct {
KeyLookup string
// Name of the session cookie. This cookie will store session key.
// Optional. Default value "_csrf".
// Optional. Default value "csrf_".
CookieName string
// Domain of the CSRF cookie.

View File

@ -2,13 +2,11 @@ package csrf
import (
"errors"
"fmt"
"net/textproto"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
)
// New creates a new middleware handler
@ -16,10 +14,8 @@ func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Set default values
if cfg.Storage == nil {
cfg.Storage = memory.New()
}
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
@ -39,14 +35,10 @@ func New(config ...Config) fiber.Handler {
case "param":
extractor = csrfFromParam(selectors[1])
case "cookie":
if selectors[1] == cfg.CookieName {
panic(fmt.Sprintf("KeyLookup key %s can't be the same as CookieName %s", selectors[1], cfg.CookieName))
}
extractor = csrfFromCookie(selectors[1])
}
// We only use Keys in Storage, so we need a dummy value
dummyVal := []byte{'+'}
dummyValue := []byte{'+'}
// Return new handler
return func(c *fiber.Ctx) (err error) {
@ -69,9 +61,7 @@ func New(config ...Config) fiber.Handler {
token = cfg.KeyGenerator()
// Add token to Storage
if err = cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil {
fmt.Println("[CSRF]", err.Error())
}
manager.setRaw(token, dummyValue, cfg.Expiration)
}
// Create cookie to pass token to client
@ -85,22 +75,19 @@ func New(config ...Config) fiber.Handler {
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
}
// Set cookie to response
c.Cookie(cookie)
case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut:
// Verify CSRF token
// Extract token from client request i.e. header, query, param, form or cookie
token, err = extractor(c)
if err != nil {
return fiber.ErrForbidden
}
// We have a problem extracting the csrf token from Storage
if _, err = cfg.Storage.Get(token); err != nil {
// The token is invalid, let client generate a new one
if err = cfg.Storage.Delete(token); err != nil {
fmt.Println("[CSRF]", err.Error())
}
// 403 if token does not exist in Storage
if manager.getRaw(token) == nil {
// Expire cookie
c.Cookie(&fiber.Cookie{
Name: cfg.CookieName,
@ -111,8 +98,13 @@ func New(config ...Config) fiber.Handler {
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
})
// Return 403 Forbidden
return fiber.ErrForbidden
}
// The token is validated, time to delete it
manager.delete(token)
}
// Protect clients from caching the response by telling the browser

112
middleware/csrf/manager.go Normal file
View File

@ -0,0 +1,112 @@
package csrf
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -0,0 +1,90 @@
package csrf
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 0
err = en.Append(0x80)
if err != nil {
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 0
o = append(o, 0x80)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1
return
}

View File

@ -59,6 +59,11 @@ app.Get("/", func(c *fiber.Ctx) error {
```go
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Weak indicates that a weak validator is used. Weak etags are easy
// to generate, but are far less useful for comparisons. Strong
// validators are ideal for comparisons but can be very difficult
@ -68,11 +73,6 @@ type Config struct {
// when byte range requests are used, but strong etags mean range
// requests can still be cached.
Weak bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
}
```
@ -80,7 +80,7 @@ type Config struct {
```go
var ConfigDefault = Config{
Weak: false,
Next: nil,
Weak: false,
}
```

44
middleware/etag/config.go Normal file
View File

@ -0,0 +1,44 @@
package etag
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Weak indicates that a weak validator is used. Weak etags are easy
// to generate, but are far less useful for comparisons. Strong
// validators are ideal for comparisons but can be very difficult
// to generate efficiently. Weak ETag values of two representations
// of the same resources might be semantically equivalent, but not
// byte-for-byte identical. This means weak etags prevent caching
// when byte range requests are used, but strong etags mean range
// requests can still be cached.
Weak bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Weak: false,
Next: nil,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
return cfg
}

View File

@ -8,42 +8,13 @@ import (
"github.com/gofiber/fiber/v2/internal/bytebufferpool"
)
// Config defines the config for middleware.
type Config struct {
// Weak indicates that a weak validator is used. Weak etags are easy
// to generate, but are far less useful for comparisons. Strong
// validators are ideal for comparisons but can be very difficult
// to generate efficiently. Weak ETag values of two representations
// of the same resources might be semantically equivalent, but not
// byte-for-byte identical. This means weak etags prevent caching
// when byte range requests are used, but strong etags mean range
// requests can still be cached.
Weak bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Weak: false,
Next: nil,
}
var normalizedHeaderETag = []byte("Etag")
var weakPrefix = []byte("W/")
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
}
cfg := configDefault(config...)
var crc32q = crc32.MakeTable(0xD5828281)

View File

@ -1,7 +1,6 @@
package limiter
import (
"fmt"
"strconv"
"sync"
"sync/atomic"
@ -26,16 +25,16 @@ func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
var (
// Limiter settings
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
mux = &sync.RWMutex{}
// Default store logic (if no Store is provided)
entries = make(map[string]entry)
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
go func() {
for {
@ -54,65 +53,39 @@ func New(config ...Config) fiber.Handler {
// Get key from request
key := cfg.KeyGenerator(c)
// Create new entry
entry := entry{}
// Lock entry
mux.Lock()
defer mux.Unlock()
// Use Storage if provided
if cfg.Storage != nil {
val, err := cfg.Storage.Get(key)
if val != nil && len(val) > 0 {
if _, err := entry.UnmarshalMsg(val); err != nil {
return err
}
}
if err != nil && err.Error() != errNotExist {
fmt.Println("[LIMITER]", err.Error())
}
} else {
entry = entries[key]
}
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Set expiration if entry does not exist
if entry.exp == 0 {
entry.exp = ts + expiration
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= entry.exp {
} else if ts >= e.exp {
// Check if entry is expired
entry.hits = 0
entry.exp = ts + expiration
e.hits = 0
e.exp = ts + expiration
}
// Increment hits
entry.hits++
// Use Storage if provided
if cfg.Storage != nil {
// Marshal entry to bytes
val, err := entry.MarshalMsg(nil)
if err != nil {
return err
}
// Pass value to Storage
if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
return err
}
} else {
entries[key] = entry
}
e.hits++
// Calculate when it resets in seconds
expire := entry.exp - ts
expire := e.exp - ts
// Set how many hits we have left
remaining := cfg.Max - entry.hits
remaining := cfg.Max - e.hits
// Update storage
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {

View File

@ -1,7 +1,6 @@
package limiter
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -172,7 +171,6 @@ func Test_Limiter_Headers(t *testing.T) {
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
}
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
fmt.Println(v)
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
}
}

View File

@ -0,0 +1,115 @@
package limiter
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type item struct {
hits int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item)
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
e.hits = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) (it *item) {
if m.storage != nil {
it = m.acquire()
if raw, _ := m.storage.Get(key); raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return
}
}
return
}
if it, _ = m.memory.Get(key).(*item); it == nil {
it = m.acquire()
}
return
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) (raw []byte) {
if m.storage != nil {
raw, _ = m.storage.Get(key)
} else {
raw, _ = m.memory.Get(key).([]byte)
}
return
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp)
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp)
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) delete(key string) {
if m.storage != nil {
_ = m.storage.Delete(key)
} else {
m.memory.Delete(key)
}
}

View File

@ -7,7 +7,7 @@ import (
)
// DecodeMsg implements msgp.Decodable
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
@ -48,7 +48,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
}
// EncodeMsg implements msgp.Encodable
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "hits"
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
@ -74,7 +74,7 @@ func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
}
// MarshalMsg implements msgp.Marshaler
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "hits"
@ -87,7 +87,7 @@ func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
@ -129,7 +129,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z entry) Msgsize() (s int) {
func (z item) Msgsize() (s int) {
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

View File

@ -1,10 +0,0 @@
package limiter
// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
hits int `msg:"hits"`
exp uint64 `msg:"exp"`
}

View File

@ -129,7 +129,7 @@ func New(config ...Config) fiber.Handler {
// Set error handler once
once.Do(func() {
errHandler = c.App().Config().ErrorHandler
// get longested possible path
stack := c.App().Stack()
for m := range stack {
for r := range stack[m] {
@ -139,7 +139,8 @@ func New(config ...Config) fiber.Handler {
}
}
}
// override error handler
errHandler = c.App().Config().ErrorHandler
})
// Set latency start time

View File

@ -14,13 +14,25 @@ _NOTE: This middleware uses our [Storage](https://github.com/gofiber/storage) pa
## Signatures
```go
func New(config ...Config) fiber.Handler
func New(config ...Config) *Store
func (s *Store) RegisterType(i interface{})
func (s *Store) Get(c *fiber.Ctx) (*Session, error)
func (s *Store) Reset() error
func (s *Session) Get(key string) interface{}
func (s *Session) Set(key string, val interface{})
func (s *Session) Delete(key string)
func (s *Session) Destroy() error
func (s *Session) Regenerate() error
func (s *Session) Save() error
func (s *Session) Fresh() bool
func (s *Session) ID() string
```
## Examples
First import the middleware from Fiber,
**⚠ _Storing `interface{}` values are limited to built-ins Go types_**
### Examples
Import the middleware package that is part of the Fiber web framework
```go
import (
"github.com/gofiber/fiber/v2"
@ -45,9 +57,6 @@ app.Get("/", func(c *fiber.Ctx) error {
panic(err)
}
// save session
defer sess.Save()
// Get value
name := sess.Get("name")
@ -62,6 +71,11 @@ app.Get("/", func(c *fiber.Ctx) error {
panic(err)
}
// save session
if err := sess.Save(); err != nil {
panic(err)
}
return fmt.Fprintf(ctx, "Welcome %v", name)
})
```

View File

@ -42,7 +42,7 @@ type Config struct {
CookieSameSite string
// KeyGenerator generates the session key.
// Optional. Default value utils.UUID
// Optional. Default value utils.UUIDv4
KeyGenerator func() string
}
@ -50,7 +50,7 @@ type Config struct {
var ConfigDefault = Config{
Expiration: 24 * time.Hour,
CookieName: "session_id",
KeyGenerator: utils.UUID,
KeyGenerator: utils.UUIDv4,
}
// Helper function to set default values
@ -67,6 +67,9 @@ func configDefault(config ...Config) Config {
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.CookieName == "" {
cfg.CookieName = ConfigDefault.CookieName
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}

View File

@ -0,0 +1,62 @@
package session
import (
"sync"
)
// go:generate msgp
// msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type data struct {
sync.RWMutex `gotiny:"-"`
d map[string]interface{} `gotiny:"d"`
}
var dataPool = sync.Pool{
New: func() interface{} {
d := new(data)
d.d = make(map[string]interface{})
return d
},
}
func acquireData() *data {
return dataPool.Get().(*data)
}
func releaseData(d *data) {
d.Reset()
dataPool.Put(d)
}
func (d *data) Reset() {
d.Lock()
for key := range d.d {
delete(d.d, key)
}
d.Unlock()
}
func (d *data) Get(key string) interface{} {
d.RLock()
v := d.d[key]
d.RUnlock()
return v
}
func (d *data) Set(key string, value interface{}) {
d.Lock()
d.d[key] = value
d.Unlock()
}
func (d *data) Delete(key string) {
d.Lock()
delete(d.d, key)
d.Unlock()
}
func (d *data) Len() int {
return len(d.d)
}

View File

@ -1,84 +0,0 @@
package session
// go:generate msgp
// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type db struct {
d []kv
}
// go:generate msgp
type kv struct {
k string
v interface{}
}
func (d *db) Reset() {
d.d = d.d[:0]
}
func (d *db) Get(key string) interface{} {
idx := d.indexOf(key)
if idx > -1 {
return d.d[idx].v
}
return nil
}
func (d *db) Set(key string, value interface{}) {
idx := d.indexOf(key)
if idx > -1 {
kv := &d.d[idx]
kv.v = value
} else {
d.append(key, value)
}
}
func (d *db) Delete(key string) {
idx := d.indexOf(key)
if idx > -1 {
n := len(d.d) - 1
d.swap(idx, n)
d.d = d.d[:n]
}
}
func (d *db) Len() int {
return len(d.d)
}
func (d *db) swap(i, j int) {
iKey, iValue := d.d[i].k, d.d[i].v
jKey, jValue := d.d[j].k, d.d[j].v
d.d[i].k, d.d[i].v = jKey, jValue
d.d[j].k, d.d[j].v = iKey, iValue
}
func (d *db) allocPage() *kv {
n := len(d.d)
if cap(d.d) > n {
d.d = d.d[:n+1]
} else {
d.d = append(d.d, kv{})
}
return &d.d[n]
}
func (d *db) append(key string, value interface{}) {
kv := d.allocPage()
kv.k = key
kv.v = value
}
func (d *db) indexOf(key string) int {
n := len(d.d)
for i := 0; i < n; i++ {
if d.d[i].k == key {
return i
}
}
return -1
}

View File

@ -1,365 +0,0 @@
package session
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/gofiber/fiber/v2/internal/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "d":
var zb0002 uint32
zb0002, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err, "d")
return
}
if cap(z.d) >= int(zb0002) {
z.d = (z.d)[:zb0002]
} else {
z.d = make([]kv, zb0002)
}
for za0001 := range z.d {
var zb0003 uint32
zb0003, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
for zb0003 > 0 {
zb0003--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.d[za0001].k, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
case "v":
z.d[za0001].v, err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
}
}
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 1
// write "d"
err = en.Append(0x81, 0xa1, 0x64)
if err != nil {
return
}
err = en.WriteArrayHeader(uint32(len(z.d)))
if err != nil {
err = msgp.WrapError(err, "d")
return
}
for za0001 := range z.d {
// map header, size 2
// write "k"
err = en.Append(0x82, 0xa1, 0x6b)
if err != nil {
return
}
err = en.WriteString(z.d[za0001].k)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
// write "v"
err = en.Append(0xa1, 0x76)
if err != nil {
return
}
err = en.WriteIntf(z.d[za0001].v)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 1
// string "d"
o = append(o, 0x81, 0xa1, 0x64)
o = msgp.AppendArrayHeader(o, uint32(len(z.d)))
for za0001 := range z.d {
// map header, size 2
// string "k"
o = append(o, 0x82, 0xa1, 0x6b)
o = msgp.AppendString(o, z.d[za0001].k)
// string "v"
o = append(o, 0xa1, 0x76)
o, err = msgp.AppendIntf(o, z.d[za0001].v)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
}
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "d":
var zb0002 uint32
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d")
return
}
if cap(z.d) >= int(zb0002) {
z.d = (z.d)[:zb0002]
} else {
z.d = make([]kv, zb0002)
}
for za0001 := range z.d {
var zb0003 uint32
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
for zb0003 > 0 {
zb0003--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "k")
return
}
case "v":
z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001, "v")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err, "d", za0001)
return
}
}
}
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *db) Msgsize() (s int) {
s = 1 + 2 + msgp.ArrayHeaderSize
for za0001 := range z.d {
s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v)
}
return
}
// DecodeMsg implements msgp.Decodable
func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.k, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "k")
return
}
case "v":
z.v, err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "v")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z kv) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "k"
err = en.Append(0x82, 0xa1, 0x6b)
if err != nil {
return
}
err = en.WriteString(z.k)
if err != nil {
err = msgp.WrapError(err, "k")
return
}
// write "v"
err = en.Append(0xa1, 0x76)
if err != nil {
return
}
err = en.WriteIntf(z.v)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z kv) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "k"
o = append(o, 0x82, 0xa1, 0x6b)
o = msgp.AppendString(o, z.k)
// string "v"
o = append(o, 0xa1, 0x76)
o, err = msgp.AppendIntf(o, z.v)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "k":
z.k, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "k")
return
}
case "v":
z.v, bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "v")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z kv) Msgsize() (s int) {
s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v)
return
}

View File

@ -5,16 +5,17 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/gotiny"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
type Session struct {
ctx *fiber.Ctx
config *Store
db *db
id string
fresh bool
id string // session id
fresh bool // if new session
ctx *fiber.Ctx // fiber context
config *Store // store configuration
data *data // key value data
}
var sessionPool = sync.Pool{
@ -25,19 +26,20 @@ var sessionPool = sync.Pool{
func acquireSession() *Session {
s := sessionPool.Get().(*Session)
s.db = new(db)
if s.data == nil {
s.data = acquireData()
}
s.fresh = true
return s
}
func releaseSession(s *Session) {
s.id = ""
s.ctx = nil
s.config = nil
if s.db != nil {
s.db.Reset()
if s.data != nil {
s.data.Reset()
}
s.id = ""
s.fresh = true
sessionPool.Put(s)
}
@ -53,25 +55,42 @@ func (s *Session) ID() string {
// Get will return the value
func (s *Session) Get(key string) interface{} {
return s.db.Get(key)
// Better safe than sorry
if s.data == nil {
return nil
}
return s.data.Get(key)
}
// Set will update or create a new key value
func (s *Session) Set(key string, val interface{}) {
s.db.Set(key, val)
// Better safe than sorry
if s.data == nil {
return
}
s.data.Set(key, val)
}
// Delete will delete the value
func (s *Session) Delete(key string) {
s.db.Delete(key)
// Better safe than sorry
if s.data == nil {
return
}
s.data.Delete(key)
}
// Destroy will delete the session from Storage and expire session cookie
func (s *Session) Destroy() error {
// Reset local data
s.db.Reset()
// Better safe than sorry
if s.data == nil {
return nil
}
// Delete data from storage
// Reset local data
s.data.Reset()
// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
}
@ -88,6 +107,7 @@ func (s *Session) Regenerate() error {
if err := s.config.Storage.Delete(s.id); err != nil {
return err
}
// Create new ID
s.id = s.config.KeyGenerator()
@ -96,26 +116,34 @@ func (s *Session) Regenerate() error {
// Save will update the storage and client cookie
func (s *Session) Save() error {
// Don't save to Storage if no data is available
if s.db.Len() <= 0 {
// Better safe than sorry
if s.data == nil {
return nil
}
// Convert book to bytes
data, err := s.db.MarshalMsg(nil)
if err != nil {
return err
// Create cookie with the session ID if fresh
if s.fresh {
s.setCookie()
}
// Don't save to Storage if no data is available
if s.data.Len() <= 0 {
return nil
}
// Convert data to bytes
mux.Lock()
data := gotiny.Marshal(&s.data)
mux.Unlock()
// pass raw bytes with session id to provider
if err := s.config.Storage.Set(s.id, data, s.config.Expiration); err != nil {
return err
}
// Create cookie with the session ID
s.setCookie()
// release session to pool to be re-used on next request
// Release session
// TODO: It's not safe to use the Session after called Save()
releaseSession(s)
return nil

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
@ -66,6 +67,108 @@ func Test_Session(t *testing.T) {
utils.AssertEqual(t, 36, len(id))
}
// go test -run Test_Session_Types
func Test_Session_Types(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.CookieName, "123")
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
type User struct {
Name string
}
var vuser = User{
Name: "John",
}
// set value
var vbool bool = true
var vstring string = "str"
var vint int = 13
var vint8 int8 = 13
var vint16 int16 = 13
var vint32 int32 = 13
var vint64 int64 = 13
var vuint uint = 13
var vuint8 uint8 = 13
var vuint16 uint16 = 13
var vuint32 uint32 = 13
var vuint64 uint64 = 13
var vuintptr uintptr = 13
var vbyte byte = 'k'
var vrune rune = 'k'
var vfloat32 float32 = 13
var vfloat64 float64 = 13
var vcomplex64 complex64 = 13
var vcomplex128 complex128 = 13
sess.Set("vuser", vuser)
sess.Set("vbool", vbool)
sess.Set("vstring", vstring)
sess.Set("vint", vint)
sess.Set("vint8", vint8)
sess.Set("vint16", vint16)
sess.Set("vint32", vint32)
sess.Set("vint64", vint64)
sess.Set("vuint", vuint)
sess.Set("vuint8", vuint8)
sess.Set("vuint16", vuint16)
sess.Set("vuint32", vuint32)
sess.Set("vuint32", vuint32)
sess.Set("vuint64", vuint64)
sess.Set("vuintptr", vuintptr)
sess.Set("vbyte", vbyte)
sess.Set("vrune", vrune)
sess.Set("vfloat32", vfloat32)
sess.Set("vfloat64", vfloat64)
sess.Set("vcomplex64", vcomplex64)
sess.Set("vcomplex128", vcomplex128)
// save session
err = sess.Save()
utils.AssertEqual(t, nil, err)
// get session
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
// get value
utils.AssertEqual(t, vuser, sess.Get("vuser").(User))
utils.AssertEqual(t, vbool, sess.Get("vbool").(bool))
utils.AssertEqual(t, vstring, sess.Get("vstring").(string))
utils.AssertEqual(t, vint, sess.Get("vint").(int))
utils.AssertEqual(t, vint8, sess.Get("vint8").(int8))
utils.AssertEqual(t, vint16, sess.Get("vint16").(int16))
utils.AssertEqual(t, vint32, sess.Get("vint32").(int32))
utils.AssertEqual(t, vint64, sess.Get("vint64").(int64))
utils.AssertEqual(t, vuint, sess.Get("vuint").(uint))
utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8))
utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16))
utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32))
utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64))
utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr))
utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte))
utils.AssertEqual(t, vrune, sess.Get("vrune").(rune))
utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32))
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
}
// go test -run Test_Session_Store_Reset
func Test_Session_Store_Reset(t *testing.T) {
t.Parallel()
@ -167,6 +270,37 @@ func Test_Session_Cookie(t *testing.T) {
sess, _ := store.Get(ctx)
sess.Save()
// cookie should not be set if empty data
utils.AssertEqual(t, 0, len(ctx.Response().Header.PeekCookie(store.CookieName)))
// cookie should be set on Save ( even if empty data )
utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.CookieName)))
}
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
func Benchmark_Session(b *testing.B) {
app, store := fiber.New(), New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.CookieName, "12356789")
b.Run("default", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c)
sess.Set("john", "doe")
_ = sess.Save()
}
})
b.Run("storage", func(b *testing.B) {
store = New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c)
sess.Set("john", "doe")
_ = sess.Save()
}
})
}

View File

@ -1,7 +1,10 @@
package session
import (
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/gotiny"
"github.com/gofiber/fiber/v2/internal/storage/memory"
)
@ -9,8 +12,7 @@ type Store struct {
Config
}
// Storage ErrNotExist
var errNotExist = "key does not exist"
var mux sync.Mutex
func New(config ...Config) *Store {
// Set default config
@ -25,6 +27,13 @@ func New(config ...Config) *Store {
}
}
// RegisterType will allow you to encode/decode custom types
// into any Storage provider
func (s *Store) RegisterType(i interface{}) {
gotiny.Register(i)
}
// Get will get/create a session
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
var fresh bool
@ -42,19 +51,21 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
sess.ctx = c
sess.config = s
sess.id = id
sess.fresh = fresh
// Fetch existing data
if !fresh {
raw, err := s.Storage.Get(id)
// Unmashal if we found data
if err == nil {
if _, err = sess.db.UnmarshalMsg(raw); err != nil {
return nil, err
}
if raw != nil && err == nil {
mux.Lock()
gotiny.Unmarshal(raw, &sess.data)
mux.Unlock()
sess.fresh = false
} else if err.Error() != errNotExist {
// Only return error if it's not ErrNotExist
} else if err != nil {
return nil, err
} else {
sess.fresh = true
}
}

View File

@ -349,9 +349,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router {
if maxAge > 0 {
cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
}
if config[0].CacheDuration != 0 {
fs.CacheDuration = config[0].CacheDuration
}
fs.CacheDuration = config[0].CacheDuration
fs.Compress = config[0].Compress
fs.AcceptByteRange = config[0].ByteRange
fs.GenerateIndexPages = config[0].Browse

View File

@ -16,17 +16,19 @@ import (
)
// AssertEqual checks if values are equal
func AssertEqual(t testing.TB, expected interface{}, actual interface{}, description ...string) {
func AssertEqual(t testing.TB, expected, actual interface{}, description ...string) {
if reflect.DeepEqual(expected, actual) {
return
}
var aType = "<nil>"
var bType = "<nil>"
if reflect.ValueOf(expected).IsValid() {
aType = reflect.TypeOf(expected).Name()
if expected != nil {
aType = fmt.Sprintf("%s", reflect.TypeOf(expected))
}
if reflect.ValueOf(actual).IsValid() {
bType = reflect.TypeOf(actual).Name()
if actual != nil {
bType = fmt.Sprintf("%s", reflect.TypeOf(actual))
}
testName := "AssertEqual"
@ -40,13 +42,11 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip
w := tabwriter.NewWriter(&buf, 0, 0, 5, ' ', 0)
fmt.Fprintf(w, "\nTest:\t%s", testName)
fmt.Fprintf(w, "\nTrace:\t%s:%d", filepath.Base(file), line)
fmt.Fprintf(w, "\nError:\tNot equal")
fmt.Fprintf(w, "\nExpect:\t%v\t[%s]", expected, aType)
fmt.Fprintf(w, "\nResult:\t%v\t[%s]", actual, bType)
if len(description) > 0 {
fmt.Fprintf(w, "\nDescription:\t%s", description[0])
}
fmt.Fprintf(w, "\nExpect:\t%v\t(%s)", expected, aType)
fmt.Fprintf(w, "\nResult:\t%v\t(%s)", actual, bType)
result := ""
if err := w.Flush(); err != nil {
@ -54,6 +54,7 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip
} else {
result = buf.String()
}
if t != nil {
t.Fatal(result)
} else {

View File

@ -6,7 +6,7 @@ package utils
import "testing"
func Test_Utils_AssertEqual(t *testing.T) {
func Test_AssertEqual(t *testing.T) {
t.Parallel()
AssertEqual(nil, []string{}, []string{})
AssertEqual(t, []string{}, []string{})

View File

@ -56,7 +56,7 @@ func TrimBytes(b []byte, cutset byte) []byte {
}
// EqualFold the equivalent of bytes.EqualFold
func EqualsFold(b, s []byte) (equals bool) {
func EqualFoldBytes(b, s []byte) (equals bool) {
n := len(b)
equals = n == len(s)
if equals {

View File

@ -9,7 +9,7 @@ import (
"testing"
)
func Test_Utils_ToLowerBytes(t *testing.T) {
func Test_ToLowerBytes(t *testing.T) {
t.Parallel()
res := ToLowerBytes([]byte("/MY/NAME/IS/:PARAM/*"))
AssertEqual(t, true, bytes.Equal([]byte("/my/name/is/:param/*"), res))
@ -41,7 +41,7 @@ func Benchmark_ToLowerBytes(b *testing.B) {
})
}
func Test_Utils_ToUpperBytes(t *testing.T) {
func Test_ToUpperBytes(t *testing.T) {
t.Parallel()
res := ToUpperBytes([]byte("/my/name/is/:param/*"))
AssertEqual(t, true, bytes.Equal([]byte("/MY/NAME/IS/:PARAM/*"), res))
@ -73,7 +73,7 @@ func Benchmark_ToUpperBytes(b *testing.B) {
})
}
func Test_Utils_TrimRightBytes(t *testing.T) {
func Test_TrimRightBytes(t *testing.T) {
t.Parallel()
res := TrimRightBytes([]byte("/test//////"), '/')
AssertEqual(t, []byte("/test"), res)
@ -99,7 +99,7 @@ func Benchmark_TrimRightBytes(b *testing.B) {
})
}
func Test_Utils_TrimLeftBytes(t *testing.T) {
func Test_TrimLeftBytes(t *testing.T) {
t.Parallel()
res := TrimLeftBytes([]byte("////test/"), '/')
AssertEqual(t, []byte("test/"), res)
@ -123,7 +123,7 @@ func Benchmark_TrimLeftBytes(b *testing.B) {
AssertEqual(b, []byte("foobar"), res)
})
}
func Test_Utils_TrimBytes(t *testing.T) {
func Test_TrimBytes(t *testing.T) {
t.Parallel()
res := TrimBytes([]byte(" test "), ' ')
AssertEqual(t, []byte("test"), res)
@ -151,14 +151,14 @@ func Benchmark_TrimBytes(b *testing.B) {
})
}
func Benchmark_EqualFolds(b *testing.B) {
func Benchmark_EqualFoldBytes(b *testing.B) {
var left = []byte("/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts")
var right = []byte("/RePos/goFiber/Fiber/issues/187643/COMMENTS")
var res bool
b.Run("fiber", func(b *testing.B) {
for n := 0; n < b.N; n++ {
res = EqualsFold(left, right)
res = EqualFoldBytes(left, right)
}
AssertEqual(b, true, res)
})
@ -170,18 +170,18 @@ func Benchmark_EqualFolds(b *testing.B) {
})
}
func Test_Utils_EqualsFold(t *testing.T) {
func Test_EqualFoldBytes(t *testing.T) {
t.Parallel()
res := EqualsFold([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
res := EqualFoldBytes([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
AssertEqual(t, true, res)
res = EqualsFold([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
res = EqualFoldBytes([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
AssertEqual(t, true, res)
res = EqualsFold([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
res = EqualFoldBytes([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
AssertEqual(t, false, res)
res = EqualsFold([]byte("/dddddd"), []byte("eeeeee"))
res = EqualFoldBytes([]byte("/dddddd"), []byte("eeeeee"))
AssertEqual(t, false, res)
res = EqualsFold([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
res = EqualFoldBytes([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
AssertEqual(t, true, res)
res = EqualsFold([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
res = EqualFoldBytes([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
AssertEqual(t, true, res)
}

View File

@ -13,6 +13,8 @@ import (
"runtime"
"sync"
"sync/atomic"
googleuuid "github.com/gofiber/fiber/v2/internal/uuid"
)
const toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
@ -63,6 +65,16 @@ func UUID() string {
return GetString(b)
}
// UUIDv4 returns a Random (Version 4) UUID.
// The strength of the UUIDs is based on the strength of the crypto/rand package.
func UUIDv4() string {
token, err := googleuuid.NewRandom()
if err != nil {
return UUID()
}
return token.String()
}
// FunctionName returns function name
func FunctionName(fn interface{}) string {
t := reflect.ValueOf(fn).Type()

View File

@ -10,26 +10,26 @@ import (
"testing"
)
func Test_Utils_FunctionName(t *testing.T) {
func Test_FunctionName(t *testing.T) {
t.Parallel()
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_UUID", FunctionName(Test_Utils_UUID))
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_UUID", FunctionName(Test_UUID))
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_FunctionName.func1", FunctionName(func() {}))
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_FunctionName.func1", FunctionName(func() {}))
var dummyint = 20
AssertEqual(t, "int", FunctionName(dummyint))
}
func Test_Utils_UUID(t *testing.T) {
func Test_UUID(t *testing.T) {
t.Parallel()
res := UUID()
AssertEqual(t, 36, len(res))
AssertEqual(t, true, res != "00000000-0000-0000-0000-000000000000")
}
func Test_Utils_UUID_Concurrency(t *testing.T) {
func Test_UUID_Concurrency(t *testing.T) {
t.Parallel()
iterations := 10000
iterations := 1000
var res string
ch := make(chan string, iterations)
results := make(map[string]string)
@ -45,6 +45,30 @@ func Test_Utils_UUID_Concurrency(t *testing.T) {
AssertEqual(t, iterations, len(results))
}
func Test_UUIDv4(t *testing.T) {
t.Parallel()
res := UUIDv4()
AssertEqual(t, 36, len(res))
AssertEqual(t, true, res != "00000000-0000-0000-0000-000000000000")
}
func Test_UUIDv4_Concurrency(t *testing.T) {
t.Parallel()
iterations := 1000
var res string
ch := make(chan string, iterations)
results := make(map[string]string)
for i := 0; i < iterations; i++ {
go func() {
ch <- UUIDv4()
}()
}
for i := 0; i < iterations; i++ {
res = <-ch
results[res] = res
}
AssertEqual(t, iterations, len(results))
}
// go test -v -run=^$ -bench=Benchmark_UUID -benchmem -count=2
func Benchmark_UUID(b *testing.B) {

View File

@ -28,13 +28,13 @@ func UnsafeBytes(s string) (bs []byte) {
return
}
// SafeString copies a string to make it immutable
func SafeString(s string) string {
// CopyString copies a string to make it immutable
func CopyString(s string) string {
return string(UnsafeBytes(s))
}
// SafeBytes copies a slice to make it immutable
func SafeBytes(b []byte) []byte {
// CopyBytes copies a slice to make it immutable
func CopyBytes(b []byte) []byte {
tmp := make([]byte, len(b))
copy(tmp, b)
return tmp
@ -83,22 +83,3 @@ func ByteSize(bytes uint64) string {
result = strings.TrimSuffix(result, ".0")
return result + unit
}
// Deprecated fn's
// #nosec G103
// GetString returns a string pointer without allocation
func GetString(b []byte) string {
return UnsafeString(b)
}
// #nosec G103
// GetBytes returns a byte pointer without allocation
func GetBytes(s string) []byte {
return UnsafeBytes(s)
}
// ImmutableString copies a string to make it immutable
func ImmutableString(s string) string {
return SafeString(s)
}

View File

@ -6,7 +6,7 @@ package utils
import "testing"
func Test_Utils_GetString(t *testing.T) {
func Test_GetString(t *testing.T) {
t.Parallel()
res := GetString([]byte("Hello, World!"))
AssertEqual(t, "Hello, World!", res)
@ -31,7 +31,7 @@ func Benchmark_GetString(b *testing.B) {
})
}
func Test_Utils_GetBytes(t *testing.T) {
func Test_GetBytes(t *testing.T) {
t.Parallel()
res := GetBytes("Hello, World!")
AssertEqual(t, []byte("Hello, World!"), res)
@ -56,7 +56,7 @@ func Benchmark_GetBytes(b *testing.B) {
})
}
func Test_Utils_ImmutableString(t *testing.T) {
func Test_ImmutableString(t *testing.T) {
t.Parallel()
res := ImmutableString("Hello, World!")
AssertEqual(t, "Hello, World!", res)

33
utils/deprecated.go Normal file
View File

@ -0,0 +1,33 @@
package utils
// #nosec G103
// DEPRECATED, Please use UnsafeString instead
func GetString(b []byte) string {
return UnsafeString(b)
}
// #nosec G103
// DEPRECATED, Please use UnsafeBytes instead
func GetBytes(s string) []byte {
return UnsafeBytes(s)
}
// DEPRECATED, Please use CopyString instead
func ImmutableString(s string) string {
return CopyString(s)
}
// DEPRECATED, please use EqualFoldBytes
func EqualsFold(b, s []byte) (equals bool) {
return EqualFoldBytes(b, s)
}
// DEPRECATED, Please use CopyString instead
func SafeString(s string) string {
return CopyString(s)
}
// DEPRECATED, Please use CopyBytes instead
func SafeBytes(b []byte) []byte {
return CopyBytes(b)
}

View File

@ -10,7 +10,7 @@ import (
"testing"
)
func Test_Utils_GetMIME(t *testing.T) {
func Test_GetMIME(t *testing.T) {
t.Parallel()
res := GetMIME(".json")
AssertEqual(t, "application/json", res)
@ -53,7 +53,7 @@ func Benchmark_GetMIME(b *testing.B) {
})
}
func Test_Utils_StatusMessage(t *testing.T) {
func Test_StatusMessage(t *testing.T) {
t.Parallel()
res := StatusMessage(204)
AssertEqual(t, "No Content", res)

View File

@ -60,3 +60,17 @@ func TrimRight(s string, cutset byte) string {
}
return s[:lenStr]
}
// EqualFold the equivalent of strings.EqualFold
func EqualFold(b, s string) (equals bool) {
n := len(b)
equals = n == len(s)
if equals {
for i := 0; i < n; i++ {
if equals = b[i]|0x20 == s[i]|0x20; !equals {
break
}
}
}
return
}

View File

@ -9,7 +9,7 @@ import (
"testing"
)
func Test_Utils_ToUpper(t *testing.T) {
func Test_ToUpper(t *testing.T) {
t.Parallel()
res := ToUpper("/my/name/is/:param/*")
AssertEqual(t, "/MY/NAME/IS/:PARAM/*", res)
@ -33,7 +33,7 @@ func Benchmark_ToUpper(b *testing.B) {
})
}
func Test_Utils_ToLower(t *testing.T) {
func Test_ToLower(t *testing.T) {
t.Parallel()
res := ToLower("/MY/NAME/IS/:PARAM/*")
AssertEqual(t, "/my/name/is/:param/*", res)
@ -64,7 +64,7 @@ func Benchmark_ToLower(b *testing.B) {
})
}
func Test_Utils_TrimRight(t *testing.T) {
func Test_TrimRight(t *testing.T) {
t.Parallel()
res := TrimRight("/test//////", '/')
AssertEqual(t, "/test", res)
@ -89,7 +89,7 @@ func Benchmark_TrimRight(b *testing.B) {
})
}
func Test_Utils_TrimLeft(t *testing.T) {
func Test_TrimLeft(t *testing.T) {
t.Parallel()
res := TrimLeft("////test/", '/')
AssertEqual(t, "test/", res)
@ -113,7 +113,7 @@ func Benchmark_TrimLeft(b *testing.B) {
AssertEqual(b, "foobar", res)
})
}
func Test_Utils_Trim(t *testing.T) {
func Test_Trim(t *testing.T) {
t.Parallel()
res := Trim(" test ", ' ')
AssertEqual(t, "test", res)
@ -147,3 +147,39 @@ func Benchmark_Trim(b *testing.B) {
AssertEqual(b, "foobar", res)
})
}
// go test -v -run=^$ -bench=Benchmark_EqualFold -benchmem -count=4
func Benchmark_EqualFold(b *testing.B) {
var left = "/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts"
var right = "/RePos/goFiber/Fiber/issues/187643/COMMENTS"
var res bool
b.Run("fiber", func(b *testing.B) {
for n := 0; n < b.N; n++ {
res = EqualFold(left, right)
}
AssertEqual(b, true, res)
})
b.Run("default", func(b *testing.B) {
for n := 0; n < b.N; n++ {
res = strings.EqualFold(left, right)
}
AssertEqual(b, true, res)
})
}
func Test_EqualFold(t *testing.T) {
t.Parallel()
res := EqualFold("/MY/NAME/IS/:PARAM/*", "/my/name/is/:param/*")
AssertEqual(t, true, res)
res = EqualFold("/MY1/NAME/IS/:PARAM/*", "/MY1/NAME/IS/:PARAM/*")
AssertEqual(t, true, res)
res = EqualFold("/my2/name/is/:param/*", "/my2/name")
AssertEqual(t, false, res)
res = EqualFold("/dddddd", "eeeeee")
AssertEqual(t, false, res)
res = EqualFold("/MY3/NAME/IS/:PARAM/*", "/my3/name/is/:param/*")
AssertEqual(t, true, res)
res = EqualFold("/MY4/NAME/IS/:PARAM/*", "/my4/nAME/IS/:param/*")
AssertEqual(t, true, res)
}