mirror of
https://github.com/gofiber/fiber.git
synced 2025-05-31 11:52:41 +00:00
Merge branch 'master' into master
This commit is contained in:
commit
494474aebd
1
.github/README.md
vendored
1
.github/README.md
vendored
@ -531,6 +531,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [sujit-baniya/fiber-boilerplate](https://github.com/sujit-baniya/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Contribute
|
||||
|
||||
|
1
.github/README_de.md
vendored
1
.github/README_de.md
vendored
@ -532,6 +532,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Mitwirken
|
||||
|
||||
|
2
.github/README_es.md
vendored
2
.github/README_es.md
vendored
@ -529,6 +529,8 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
|
||||
## 👍 Contribuir
|
||||
|
||||
|
2
.github/README_fr.md
vendored
2
.github/README_fr.md
vendored
@ -532,7 +532,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
## 👍 Contribuer
|
||||
|
||||
Si vous voulez nous remercier et/ou soutenir le développement actif de `Fiber`:
|
||||
|
1
.github/README_he.md
vendored
1
.github/README_he.md
vendored
@ -662,6 +662,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
</div>
|
||||
|
||||
|
1
.github/README_id.md
vendored
1
.github/README_id.md
vendored
@ -532,6 +532,7 @@ Berikut adalah kumpulan _middlewares_ yang dibuat oleh komunitas Fiber, silahkan
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Berkontribusi
|
||||
|
||||
|
17
.github/README_ja.md
vendored
17
.github/README_ja.md
vendored
@ -106,9 +106,10 @@ func main() {
|
||||
|
||||
## ⚙️ インストール
|
||||
|
||||
Make sure you have Go installed ([download](https://golang.org/dl/)). Version `1.14` or higher is required.
|
||||
Goがインストールされていることを確認してください ([ダウンロード](https://golang.org/dl/)). バージョン `1.14` またはそれ以上であることが必要です。
|
||||
|
||||
Initialize your project by creating a folder and then running `go mod init github.com/your/repo` ([learn more](https://blog.golang.org/using-go-modules)) inside the folder. Then install Fiber with the [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) command:
|
||||
フォルダを作成し、フォルダ内で `go mod init github.com/your/repo` ([learn more](https://blog.golang.org/using-go-modules)) を実行してプロジェクトを初期化してください。その後、 Fiber を以下の [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) コマンドでインストールしてください。
|
||||
|
||||
```bash
|
||||
go get -u github.com/gofiber/fiber/v2
|
||||
@ -126,7 +127,7 @@ go get -u github.com/gofiber/fiber/v2
|
||||
- [Template engines](https://github.com/gofiber/template)
|
||||
- [WebSocket support](https://github.com/gofiber/websocket)
|
||||
- [Rate Limiter](https://docs.gofiber.io/middleware#limiter)
|
||||
- Available in [15 languages](https://docs.gofiber.io/)
|
||||
- [15ヶ国語](https://docs.gofiber.io/)で利用可能
|
||||
- [Fiber](https://docs.gofiber.io/)をもっと知る
|
||||
|
||||
## 💡 哲学
|
||||
@ -142,8 +143,6 @@ Fiber は人気の高い Web フレームワークである Expressjs に**イ
|
||||
|
||||
以下に一般的な例をいくつか示します。他のコード例をご覧になりたい場合は、 [Recipes リポジトリ](https://github.com/gofiber/recipes)または[API ドキュメント](https://docs.gofiber.io)にアクセスしてください。
|
||||
|
||||
Listed below are some of the common examples. If you want to see more code examples , please visit our [Recipes repository](https://github.com/gofiber/recipes) or visit our hosted [API documentation](https://docs.gofiber.io).
|
||||
|
||||
#### 📖 [**Basic Routing**](https://docs.gofiber.io/#basic-routing)
|
||||
|
||||
```go
|
||||
@ -485,7 +484,7 @@ func main() {
|
||||
|
||||
## 🧬 Internal Middleware
|
||||
|
||||
Here is a list of middleware that are included within the Fiber framework.
|
||||
以下はFiberフレームワークに含まれるミドルウェアの一覧です。
|
||||
|
||||
| Middleware | Description |
|
||||
| :------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
@ -506,7 +505,7 @@ Here is a list of middleware that are included within the Fiber framework.
|
||||
|
||||
## 🧬 External Middleware
|
||||
|
||||
List of externally hosted middleware modules and maintained by the [Fiber team](https://github.com/orgs/gofiber/people).
|
||||
[Fiber team](https://github.com/orgs/gofiber/people) により管理・運用されているミドルウェアの一覧です。
|
||||
|
||||
| Middleware | Description |
|
||||
| :------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
@ -522,6 +521,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
|
||||
## 🌱 Third Party Middlewares
|
||||
|
||||
This is a list of middlewares that are created by the Fiber community, please create a PR if you want to see yours!
|
||||
これらはFiberのコミュニティーによって作成されたミドルウェアの一覧です。もしあなたのミドルウェアを掲載したい場合はPRを作成してください!
|
||||
|
||||
- [arsmn/fiber-casbin](https://github.com/arsmn/fiber-casbin)
|
||||
- [arsmn/fiber-introspect](https://github.com/arsmn/fiber-introspect)
|
||||
@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 貢献する
|
||||
|
||||
@ -544,11 +545,11 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
1. [GitHub Star](https://github.com/gofiber/fiber/stargazers)をつけてください 。
|
||||
2. [あなたの Twitter で](https://twitter.com/intent/tweet?text=Fiber%20is%20an%20Express%20inspired%20%23web%20%23framework%20built%20on%20top%20of%20Fasthttp%2C%20the%20fastest%20HTTP%20engine%20for%20%23Go.%20Designed%20to%20ease%20things%20up%20for%20%23fast%20development%20with%20zero%20memory%20allocation%20and%20%23performance%20in%20mind%20%F0%9F%9A%80%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber)プロジェクトについてツイートしてください。
|
||||
3. [Medium](https://medium.com/) 、 [Dev.to、](https://dev.to/)または個人のブログでレビューまたはチュートリアルを書いてください。
|
||||
4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny).
|
||||
4. [cup of coffee](https://buymeacoff.ee/fenny)の寄付でプロジェクトを支援しましょう。
|
||||
|
||||
## ☕ Supporters
|
||||
|
||||
Fiber is an open source project that runs on donations to pay the bills e.g. our domain name, gitbook, netlify and serverless hosting. If you want to support Fiber, you can ☕ [**buy a coffee here**](https://buymeacoff.ee/fenny).
|
||||
Fiberはオープンソースプロジェクトで、寄付によってドメイン名やgitbook、 netlify、そしてサーバーレスホスティングなどの費用を賄っています。もしFiberを支援したければ ☕ [**こちらから**](https://buymeacoff.ee/fenny)。
|
||||
|
||||
| | User | Donation |
|
||||
| :--------------------------------------------------------- | :----------------------------------------------- | :------- |
|
||||
|
1
.github/README_ko.md
vendored
1
.github/README_ko.md
vendored
@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 기여
|
||||
|
||||
|
1
.github/README_nl.md
vendored
1
.github/README_nl.md
vendored
@ -536,6 +536,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Bijdragen
|
||||
|
||||
|
1
.github/README_pt.md
vendored
1
.github/README_pt.md
vendored
@ -530,6 +530,7 @@ Esta é uma lista de middlewares criados pela comunidade do Fiber, se quiser ter
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Contribuindo
|
||||
|
||||
|
1
.github/README_ru.md
vendored
1
.github/README_ru.md
vendored
@ -532,6 +532,7 @@ func main() {
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Помощь проекту
|
||||
|
||||
|
1
.github/README_sa.md
vendored
1
.github/README_sa.md
vendored
@ -596,6 +596,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 مساهمة
|
||||
|
||||
|
1
.github/README_tr.md
vendored
1
.github/README_tr.md
vendored
@ -529,6 +529,7 @@ This is a list of middlewares that are created by the Fiber community, please cr
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 Destek
|
||||
|
||||
|
1
.github/README_zh-CN.md
vendored
1
.github/README_zh-CN.md
vendored
@ -531,6 +531,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 贡献
|
||||
|
||||
|
1
.github/README_zh-TW.md
vendored
1
.github/README_zh-TW.md
vendored
@ -532,6 +532,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
|
||||
- [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
|
||||
- [ansrivas/fiberprometheus](https://github.com/ansrivas/fiberprometheus)
|
||||
- [LdDl/fiber-long-poll](https://github.com/LdDl/fiber-long-poll)
|
||||
- [K0enM/fiber_vhost](https://github.com/K0enM/fiber_vhost)
|
||||
|
||||
## 👍 貢獻
|
||||
|
||||
|
44
app.go
44
app.go
@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
// Version of current fiber package
|
||||
const Version = "2.2.0"
|
||||
const Version = "2.2.5"
|
||||
|
||||
// Handler defines a function to serve HTTP requests.
|
||||
type Handler = func(*Ctx) error
|
||||
@ -39,27 +39,27 @@ type Handler = func(*Ctx) error
|
||||
// Map is a shortcut for map[string]interface{}, useful for JSON returns
|
||||
type Map map[string]interface{}
|
||||
|
||||
// Storage interface that is implemented by storage providers for different
|
||||
// middleware packages like cache, limiter, session and csrf
|
||||
// Storage interface for communicating with different database/key-value
|
||||
// providers
|
||||
type Storage interface {
|
||||
// Get retrieves the value for the given key.
|
||||
// If no value is not found it returns ErrNotExit error
|
||||
// Get gets the value for the given key.
|
||||
// It returns ErrNotFound if the storage does not contain the key.
|
||||
Get(key string) ([]byte, error)
|
||||
|
||||
// Set stores the given value for the given key along with a
|
||||
// time-to-live expiration value, 0 means live for ever
|
||||
// The key must not be "" and the empty values are ignored.
|
||||
// Empty key or value will be ignored without an error.
|
||||
Set(key string, val []byte, ttl time.Duration) error
|
||||
|
||||
// Delete deletes the stored value for the given key.
|
||||
// Deleting a non-existing key-value pair does NOT lead to an error.
|
||||
// The key must not be "".
|
||||
// Delete deletes the value for the given key.
|
||||
// It returns no error if the storage does not contain the key,
|
||||
Delete(key string) error
|
||||
|
||||
// Reset the storage
|
||||
// Reset resets the storage and delete all keys.
|
||||
Reset() error
|
||||
|
||||
// Close the storage
|
||||
// Close closes the storage and will stop any running garbage
|
||||
// collectors and open connections.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@ -149,6 +149,8 @@ type Config struct {
|
||||
ETag bool `json:"etag"`
|
||||
|
||||
// Max body size that the server accepts.
|
||||
// -1 will decline any body size
|
||||
//
|
||||
// Default: 4 * 1024 * 1024
|
||||
BodyLimit int `json:"body_limit"`
|
||||
|
||||
@ -352,7 +354,7 @@ func New(config ...Config) *App {
|
||||
}
|
||||
|
||||
// Override default values
|
||||
if app.config.BodyLimit <= 0 {
|
||||
if app.config.BodyLimit == 0 {
|
||||
app.config.BodyLimit = DefaultBodyLimit
|
||||
}
|
||||
if app.config.Concurrency <= 0 {
|
||||
@ -373,13 +375,15 @@ func New(config ...Config) *App {
|
||||
if app.config.ErrorHandler == nil {
|
||||
app.config.ErrorHandler = DefaultErrorHandler
|
||||
}
|
||||
|
||||
// Init app
|
||||
app.init()
|
||||
|
||||
// Return app
|
||||
return app
|
||||
}
|
||||
|
||||
// Mount attaches another app instance as a subrouter along a routing path.
|
||||
// Mount attaches another app instance as a sub-router along a routing path.
|
||||
// It's very useful to split up a large API as many independent routers and
|
||||
// compose them as a single service using Mount.
|
||||
func (app *App) Mount(prefix string, fiber *App) Router {
|
||||
@ -659,7 +663,7 @@ func (app *App) Test(req *http.Request, msTimeout ...int) (resp *http.Response,
|
||||
|
||||
type disableLogger struct{}
|
||||
|
||||
func (dl *disableLogger) Printf(format string, args ...interface{}) {
|
||||
func (dl *disableLogger) Printf(_ string, _ ...interface{}) {
|
||||
// fmt.Println(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
@ -735,7 +739,7 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||
logo += " │ %s │\n"
|
||||
logo += " │ %s │\n"
|
||||
logo += " │ │\n"
|
||||
logo += " │ Handlers %s Threads %s │\n"
|
||||
logo += " │ Handlers %s Processes %s │\n"
|
||||
logo += " │ Prefork .%s PID ....%s │\n"
|
||||
logo += " └───────────────────────────────────────────────────┘"
|
||||
logo += "%s"
|
||||
@ -811,11 +815,16 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||
isPrefork = "Enabled"
|
||||
}
|
||||
|
||||
procs := strconv.Itoa(runtime.GOMAXPROCS(0))
|
||||
if !app.config.Prefork {
|
||||
procs = "1"
|
||||
}
|
||||
|
||||
mainLogo := fmt.Sprintf(logo,
|
||||
cBlack,
|
||||
centerValue(" Fiber v"+Version, 49),
|
||||
center(addr, 49),
|
||||
value(strconv.Itoa(app.handlerCount), 14), value(strconv.Itoa(runtime.GOMAXPROCS(0)), 14),
|
||||
value(strconv.Itoa(app.handlerCount), 14), value(procs, 12),
|
||||
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
|
||||
cReset,
|
||||
)
|
||||
@ -908,6 +917,5 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
|
||||
out = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
|
||||
fmt.Fprintln(out, output)
|
||||
|
||||
_, _ = fmt.Fprintln(out, output)
|
||||
}
|
||||
|
22
ctx.go
22
ctx.go
@ -247,38 +247,42 @@ var decoderPool = &sync.Pool{New: func() interface{} {
|
||||
// BodyParser binds the request body to a struct.
|
||||
// It supports decoding the following content types based on the Content-Type header:
|
||||
// application/json, application/xml, application/x-www-form-urlencoded, multipart/form-data
|
||||
// If none of the content types above are matched, it will return a ErrUnprocessableEntity error
|
||||
func (c *Ctx) BodyParser(out interface{}) error {
|
||||
// Get decoder from pool
|
||||
schemaDecoder := decoderPool.Get().(*schema.Decoder)
|
||||
defer decoderPool.Put(schemaDecoder)
|
||||
|
||||
// Get content-type
|
||||
ctype := getString(c.fasthttp.Request.Header.ContentType())
|
||||
ctype := utils.ToLower(utils.UnsafeString(c.fasthttp.Request.Header.ContentType()))
|
||||
|
||||
// Parse body accordingly
|
||||
if strings.HasPrefix(ctype, MIMEApplicationJSON) {
|
||||
schemaDecoder.SetAliasTag("json")
|
||||
return json.Unmarshal(c.fasthttp.Request.Body(), out)
|
||||
} else if strings.HasPrefix(ctype, MIMEApplicationForm) {
|
||||
}
|
||||
if strings.HasPrefix(ctype, MIMEApplicationForm) {
|
||||
schemaDecoder.SetAliasTag("form")
|
||||
data := make(map[string][]string)
|
||||
c.fasthttp.PostArgs().VisitAll(func(key []byte, val []byte) {
|
||||
data[getString(key)] = append(data[getString(key)], getString(val))
|
||||
data[utils.UnsafeString(key)] = append(data[utils.UnsafeString(key)], utils.UnsafeString(val))
|
||||
})
|
||||
return schemaDecoder.Decode(out, data)
|
||||
} else if strings.HasPrefix(ctype, MIMEMultipartForm) {
|
||||
}
|
||||
if strings.HasPrefix(ctype, MIMEMultipartForm) {
|
||||
schemaDecoder.SetAliasTag("form")
|
||||
data, err := c.fasthttp.MultipartForm()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return schemaDecoder.Decode(out, data.Value)
|
||||
} else if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
|
||||
}
|
||||
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
|
||||
schemaDecoder.SetAliasTag("xml")
|
||||
return xml.Unmarshal(c.fasthttp.Request.Body(), out)
|
||||
}
|
||||
// No suitable content type found
|
||||
return fmt.Errorf("bodyparser: cannot parse content-type: %v", ctype)
|
||||
return ErrUnprocessableEntity
|
||||
}
|
||||
|
||||
// ClearCookie expires a specific cookie by key on the client side.
|
||||
@ -946,7 +950,7 @@ func (c *Ctx) SendFile(file string, compress ...bool) error {
|
||||
})
|
||||
|
||||
// Keep original path for mutable params
|
||||
c.pathOriginal = utils.SafeString(c.pathOriginal)
|
||||
c.pathOriginal = utils.CopyString(c.pathOriginal)
|
||||
// Disable compression
|
||||
if len(compress) <= 0 || !compress[0] {
|
||||
// https://github.com/valyala/fasthttp/blob/master/fs.go#L46
|
||||
@ -1017,7 +1021,7 @@ func (c *Ctx) SendStream(stream io.Reader, size ...int) error {
|
||||
|
||||
// Set sets the response's HTTP header field to the specified key, value.
|
||||
func (c *Ctx) Set(key string, val string) {
|
||||
c.fasthttp.Response.Header.Set(key, removeNewLines(val))
|
||||
c.fasthttp.Response.Header.Set(key, val)
|
||||
}
|
||||
|
||||
func (c *Ctx) setCanonical(key string, val string) {
|
||||
@ -1098,7 +1102,7 @@ func (c *Ctx) WriteString(s string) (int, error) {
|
||||
// XHR returns a Boolean property, that is true, if the request's X-Requested-With header field is XMLHttpRequest,
|
||||
// indicating that the request was issued by a client library (such as jQuery).
|
||||
func (c *Ctx) XHR() bool {
|
||||
return utils.EqualsFold(utils.UnsafeBytes(c.Get(HeaderXRequestedWith)), []byte("xmlhttprequest"))
|
||||
return utils.EqualFoldBytes(utils.UnsafeBytes(c.Get(HeaderXRequestedWith)), []byte("xmlhttprequest"))
|
||||
}
|
||||
|
||||
// prettifyPath ...
|
||||
|
5
go.mod
5
go.mod
@ -3,7 +3,6 @@ module github.com/gofiber/fiber/v2
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.11.0 // indirect
|
||||
github.com/valyala/fasthttp v1.17.0
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1
|
||||
github.com/valyala/fasthttp v1.18.0
|
||||
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018
|
||||
)
|
||||
|
12
go.sum
12
go.sum
@ -2,25 +2,25 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDa
|
||||
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg=
|
||||
github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg=
|
||||
github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg=
|
||||
github.com/valyala/fasthttp v1.17.0/go.mod h1:jjraHZVbKOXftJfsOYoAjaeygpj5hr8ermTRJNroD7A=
|
||||
github.com/valyala/fasthttp v1.18.0 h1:IV0DdMlatq9QO1Cr6wGJPVW1sV1Q8HvZXAIcjorylyM=
|
||||
github.com/valyala/fasthttp v1.18.0/go.mod h1:jjraHZVbKOXftJfsOYoAjaeygpj5hr8ermTRJNroD7A=
|
||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a h1:0R4NLDRDZX6JcmhJgXi5E4b8Wg84ihbmUKp/GvSPEzc=
|
||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg=
|
||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1 h1:a/mKvvZr9Jcc8oKfcmgzyp7OwF73JPWsQLvH1z2Kxck=
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018 h1:XKi8B/gRBuTZN1vU9gFsLMm6zVz5FSCDzm8JYACnjy8=
|
||||
golang.org/x/sys v0.0.0-20201210223839-7e3030f88018/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
2
group.go
2
group.go
@ -15,7 +15,7 @@ type Group struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
// Mount attaches another app instance as a subrouter along a routing path.
|
||||
// Mount attaches another app instance as a sub-router along a routing path.
|
||||
// It's very useful to split up a large API as many independent routers and
|
||||
// compose them as a single service using Mount.
|
||||
func (grp *Group) Mount(prefix string, fiber *App) Router {
|
||||
|
29
helpers.go
29
helpers.go
@ -94,29 +94,6 @@ func quoteString(raw string) string {
|
||||
return quoted
|
||||
}
|
||||
|
||||
// removeNewLines will replace `\r` and `\n` with an empty space
|
||||
func removeNewLines(raw string) string {
|
||||
start := 0
|
||||
if start = strings.IndexByte(raw, '\r'); start == -1 {
|
||||
if start = strings.IndexByte(raw, '\n'); start == -1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
bb := bytebufferpool.Get()
|
||||
buf := bb.Bytes()
|
||||
buf = append(buf, raw...)
|
||||
for i := start; i < len(buf); i++ {
|
||||
if buf[i] != '\r' && buf[i] != '\n' {
|
||||
continue
|
||||
}
|
||||
buf[i] = ' '
|
||||
}
|
||||
raw = utils.UnsafeString(buf)
|
||||
bytebufferpool.Put(bb)
|
||||
|
||||
return raw
|
||||
}
|
||||
|
||||
// Scan stack if other methods match the request
|
||||
func methodExist(ctx *Ctx) (exist bool) {
|
||||
for i := 0; i < len(intMethod); i++ {
|
||||
@ -364,9 +341,9 @@ func (c *testConn) Close() error { return nil }
|
||||
|
||||
func (c *testConn) LocalAddr() net.Addr { return testAddr("local-addr") }
|
||||
func (c *testConn) RemoteAddr() net.Addr { return testAddr("remote-addr") }
|
||||
func (c *testConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *testConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *testConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
func (c *testConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *testConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *testConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
// getString converts byte slice to a string without memory allocation.
|
||||
var getString = utils.UnsafeString
|
||||
|
@ -16,70 +16,6 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4
|
||||
func Benchmark_RemoveNewLines(b *testing.B) {
|
||||
withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n"
|
||||
withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue "
|
||||
expected := utils.SafeString(withoutNL)
|
||||
var res string
|
||||
|
||||
b.Run("withoutNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLines(withoutNL)
|
||||
}
|
||||
utils.AssertEqual(b, expected, res)
|
||||
})
|
||||
b.Run("withNL", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = removeNewLines(withNL)
|
||||
}
|
||||
utils.AssertEqual(b, expected, res)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=RemoveNewLines_Bytes -count=3
|
||||
func Test_RemoveNewLines_Bytes(t *testing.T) {
|
||||
app := New()
|
||||
t.Run("Not Status OK", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
c.Status(201)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
|
||||
t.Run("No Body", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
|
||||
t.Run("Has HeaderIfNoneMatch", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`)
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, 304, c.Response().StatusCode())
|
||||
utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag)))
|
||||
utils.AssertEqual(t, "", string(c.Response().Body()))
|
||||
})
|
||||
|
||||
t.Run("No HeaderIfNoneMatch", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.SendString("Hello, World!")
|
||||
setETag(c, false)
|
||||
utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag)))
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=Test_Utils_ -count=3
|
||||
func Test_Utils_ETag(t *testing.T) {
|
||||
app := New()
|
||||
|
76
internal/encoding/json/README.md
Normal file
76
internal/encoding/json/README.md
Normal file
@ -0,0 +1,76 @@
|
||||
# encoding/json [](https://godoc.org/github.com/segmentio/encoding/json)
|
||||
|
||||
Go package offering a replacement implementation of the standard library's
|
||||
[`encoding/json`](https://golang.org/pkg/encoding/json/) package, with much
|
||||
better performance.
|
||||
|
||||
## Usage
|
||||
|
||||
The exported API of this package mirrors the standard library's
|
||||
[`encoding/json`](https://golang.org/pkg/encoding/json/) package, the only
|
||||
change needed to take advantage of the performance improvements is the import
|
||||
path of the `json` package, from:
|
||||
```go
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
```
|
||||
to
|
||||
```go
|
||||
import (
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
```
|
||||
|
||||
One way to gain higher encoding throughput is to disable HTML escaping.
|
||||
It allows the string encoding to use a much more efficient code path which
|
||||
does not require parsing UTF-8 runes most of the time.
|
||||
|
||||
## Performance Improvements
|
||||
|
||||
The internal implementation uses a fair amount of unsafe operations (untyped
|
||||
code, pointer arithmetic, etc...) to avoid using reflection as much as possible,
|
||||
which is often the reason why serialization code has a large CPU and memory
|
||||
footprint.
|
||||
|
||||
The package aims for zero unnecessary dynamic memory allocations and hot code
|
||||
paths that are mostly free from calls into the reflect package.
|
||||
|
||||
## Compatibility with encoding/json
|
||||
|
||||
This package aims to be a drop-in replacement, therefore it is tested to behave
|
||||
exactly like the standard library's package. However, there are still a few
|
||||
missing features that have not been ported yet:
|
||||
|
||||
- Streaming decoder, currently the `Decoder` implementation offered by the
|
||||
package does not support progressively reading values from a JSON array (unlike
|
||||
the standard library). In our experience this is a very rare use-case, if you
|
||||
need it you're better off sticking to the standard library, or spend a bit of
|
||||
time implementing it in here ;)
|
||||
|
||||
Note that none of those features should result in performance degradations if
|
||||
they were implemented in the package, and we welcome contributions!
|
||||
|
||||
## Trade-offs
|
||||
|
||||
As one would expect, we had to make a couple of trade-offs to achieve greater
|
||||
performance than the standard library, but there were also features that we
|
||||
did not want to give away.
|
||||
|
||||
Other open-source packages offering a reduced CPU and memory footprint usually
|
||||
do so by designing a different API, or require code generation (therefore adding
|
||||
complexity to the build process). These were not acceptable conditions for us,
|
||||
as we were not willing to trade off developer productivity for better runtime
|
||||
performance. To achieve this, we chose to exactly replicate the standard
|
||||
library interfaces and behavior, which meant the package implementation was the
|
||||
only area that we were able to work with. The internals of this package make
|
||||
heavy use of unsafe pointer arithmetics and other performance optimizations,
|
||||
and therefore are not as approachable as typical Go programs. Basically, we put
|
||||
a bigger burden on maintainers to achieve better runtime cost without
|
||||
sacrificing developer productivity.
|
||||
|
||||
For these reasons, we also don't believe that this code should be ported upstream
|
||||
to the standard `encoding/json` package. The standard library has to remain
|
||||
readable and approachable to maximize stability and maintainability, and make
|
||||
projects like this one possible because a high quality reference implementation
|
||||
already exists.
|
@ -340,6 +340,24 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
|
||||
encode: encoder.encodeMapStringRawMessage,
|
||||
decode: decoder.decodeMapStringRawMessage,
|
||||
}
|
||||
|
||||
case k == stringType && v == stringType:
|
||||
return codec{
|
||||
encode: encoder.encodeMapStringString,
|
||||
decode: decoder.decodeMapStringString,
|
||||
}
|
||||
|
||||
case k == stringType && v == stringsType:
|
||||
return codec{
|
||||
encode: encoder.encodeMapStringStringSlice,
|
||||
decode: decoder.decodeMapStringStringSlice,
|
||||
}
|
||||
|
||||
case k == stringType && v == boolType:
|
||||
return codec{
|
||||
encode: encoder.encodeMapStringBool,
|
||||
decode: decoder.decodeMapStringBool,
|
||||
}
|
||||
}
|
||||
|
||||
kc := codec{}
|
||||
@ -1035,6 +1053,7 @@ var (
|
||||
|
||||
numberType = reflect.TypeOf(json.Number(""))
|
||||
stringType = reflect.TypeOf("")
|
||||
stringsType = reflect.TypeOf([]string(nil))
|
||||
bytesType = reflect.TypeOf(([]byte)(nil))
|
||||
durationType = reflect.TypeOf(time.Duration(0))
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
@ -1045,9 +1064,13 @@ var (
|
||||
timePtrType = reflect.PtrTo(timeType)
|
||||
rawMessagePtrType = reflect.PtrTo(rawMessageType)
|
||||
|
||||
sliceInterfaceType = reflect.TypeOf(([]interface{})(nil))
|
||||
mapStringInterfaceType = reflect.TypeOf((map[string]interface{})(nil))
|
||||
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
|
||||
sliceInterfaceType = reflect.TypeOf(([]interface{})(nil))
|
||||
sliceStringType = reflect.TypeOf(([]interface{})(nil))
|
||||
mapStringInterfaceType = reflect.TypeOf((map[string]interface{})(nil))
|
||||
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
|
||||
mapStringStringType = reflect.TypeOf((map[string]string)(nil))
|
||||
mapStringStringSliceType = reflect.TypeOf((map[string][]string)(nil))
|
||||
mapStringBoolType = reflect.TypeOf((map[string]bool)(nil))
|
||||
|
||||
interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
|
||||
jsonMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||
|
@ -535,7 +535,7 @@ func (d decoder) decodeArray(b []byte, p unsafe.Pointer, n int, size uintptr, t
|
||||
if err != nil {
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = t.String() + e.Struct
|
||||
e.Field = strconv.Itoa(i) + "." + e.Field
|
||||
e.Field = d.prependField(strconv.Itoa(i), e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -637,7 +637,7 @@ func (d decoder) decodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = t.String() + e.Struct
|
||||
e.Field = strconv.Itoa(s.len) + "." + e.Field
|
||||
e.Field = d.prependField(strconv.Itoa(s.len), e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -716,7 +716,7 @@ func (d decoder) decodeMap(b []byte, p unsafe.Pointer, t, kt, vt reflect.Type, k
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = "map[" + kt.String() + "]" + vt.String() + "{" + e.Struct + "}"
|
||||
e.Field = fmt.Sprint(k.Interface()) + "." + e.Field
|
||||
e.Field = d.prependField(fmt.Sprint(k.Interface()), e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -797,7 +797,7 @@ func (d decoder) decodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = mapStringInterfaceType.String() + e.Struct
|
||||
e.Field = key + "." + e.Field
|
||||
e.Field = d.prependField(key, e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -878,7 +878,254 @@ func (d decoder) decodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = mapStringRawMessageType.String() + e.Struct
|
||||
e.Field = key + "." + e.Field
|
||||
e.Field = d.prependField(key, e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
m[key] = val
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func (d decoder) decodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
if hasNullPrefix(b) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
return b[4:], nil
|
||||
}
|
||||
|
||||
if len(b) < 2 || b[0] != '{' {
|
||||
return inputError(b, mapStringStringType)
|
||||
}
|
||||
|
||||
i := 0
|
||||
m := *(*map[string]string)(p)
|
||||
|
||||
if m == nil {
|
||||
m = make(map[string]string, 64)
|
||||
}
|
||||
|
||||
var err error
|
||||
var key string
|
||||
var val string
|
||||
var input = b
|
||||
|
||||
b = b[1:]
|
||||
for {
|
||||
key = ""
|
||||
val = ""
|
||||
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) != 0 && b[0] == '}' {
|
||||
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
|
||||
return b[1:], nil
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field value")
|
||||
}
|
||||
if b[0] != ',' {
|
||||
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
}
|
||||
|
||||
if hasPrefix(b, "null") {
|
||||
return b, syntaxError(b, "cannot decode object key string from 'null' value")
|
||||
}
|
||||
|
||||
b, err = d.decodeString(b, unsafe.Pointer(&key))
|
||||
if err != nil {
|
||||
return objectKeyError(b, err)
|
||||
}
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field key")
|
||||
}
|
||||
if b[0] != ':' {
|
||||
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
|
||||
b, err = d.decodeString(b, unsafe.Pointer(&val))
|
||||
if err != nil {
|
||||
if _, r, err := parseValue(input); err != nil {
|
||||
return r, err
|
||||
} else {
|
||||
b = r
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = mapStringStringType.String() + e.Struct
|
||||
e.Field = d.prependField(key, e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
m[key] = val
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func (d decoder) decodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
if hasNullPrefix(b) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
return b[4:], nil
|
||||
}
|
||||
|
||||
if len(b) < 2 || b[0] != '{' {
|
||||
return inputError(b, mapStringStringSliceType)
|
||||
}
|
||||
|
||||
i := 0
|
||||
m := *(*map[string][]string)(p)
|
||||
|
||||
if m == nil {
|
||||
m = make(map[string][]string, 64)
|
||||
}
|
||||
|
||||
var err error
|
||||
var key string
|
||||
var buf []string
|
||||
var input = b
|
||||
var stringSize = unsafe.Sizeof("")
|
||||
|
||||
b = b[1:]
|
||||
for {
|
||||
key = ""
|
||||
buf = buf[:0]
|
||||
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) != 0 && b[0] == '}' {
|
||||
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
|
||||
return b[1:], nil
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field value")
|
||||
}
|
||||
if b[0] != ',' {
|
||||
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
}
|
||||
|
||||
if hasPrefix(b, "null") {
|
||||
return b, syntaxError(b, "cannot decode object key string from 'null' value")
|
||||
}
|
||||
|
||||
b, err = d.decodeString(b, unsafe.Pointer(&key))
|
||||
if err != nil {
|
||||
return objectKeyError(b, err)
|
||||
}
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field key")
|
||||
}
|
||||
if b[0] != ':' {
|
||||
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
|
||||
b, err = d.decodeSlice(b, unsafe.Pointer(&buf), stringSize, sliceStringType, decoder.decodeString)
|
||||
if err != nil {
|
||||
if _, r, err := parseValue(input); err != nil {
|
||||
return r, err
|
||||
} else {
|
||||
b = r
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = mapStringStringType.String() + e.Struct
|
||||
e.Field = d.prependField(key, e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
val := make([]string, len(buf))
|
||||
copy(val, buf)
|
||||
|
||||
m[key] = val
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func (d decoder) decodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
if hasNullPrefix(b) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
return b[4:], nil
|
||||
}
|
||||
|
||||
if len(b) < 2 || b[0] != '{' {
|
||||
return inputError(b, mapStringBoolType)
|
||||
}
|
||||
|
||||
i := 0
|
||||
m := *(*map[string]bool)(p)
|
||||
|
||||
if m == nil {
|
||||
m = make(map[string]bool, 64)
|
||||
}
|
||||
|
||||
var err error
|
||||
var key string
|
||||
var val bool
|
||||
var input = b
|
||||
|
||||
b = b[1:]
|
||||
for {
|
||||
key = ""
|
||||
val = false
|
||||
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) != 0 && b[0] == '}' {
|
||||
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(unsafe.Pointer(&m))
|
||||
return b[1:], nil
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field value")
|
||||
}
|
||||
if b[0] != ',' {
|
||||
return b, syntaxError(b, "expected ',' after object field value but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
}
|
||||
|
||||
if hasPrefix(b, "null") {
|
||||
return b, syntaxError(b, "cannot decode object key string from 'null' value")
|
||||
}
|
||||
|
||||
b, err = d.decodeString(b, unsafe.Pointer(&key))
|
||||
if err != nil {
|
||||
return objectKeyError(b, err)
|
||||
}
|
||||
b = skipSpaces(b)
|
||||
|
||||
if len(b) == 0 {
|
||||
return b, syntaxError(b, "unexpected end of JSON input after object field key")
|
||||
}
|
||||
if b[0] != ':' {
|
||||
return b, syntaxError(b, "expected ':' after object field key but found '%c'", b[0])
|
||||
}
|
||||
b = skipSpaces(b[1:])
|
||||
|
||||
b, err = d.decodeBool(b, unsafe.Pointer(&val))
|
||||
if err != nil {
|
||||
if _, r, err := parseValue(input); err != nil {
|
||||
return r, err
|
||||
} else {
|
||||
b = r
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = mapStringStringType.String() + e.Struct
|
||||
e.Field = d.prependField(key, e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -968,7 +1215,7 @@ func (d decoder) decodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byt
|
||||
}
|
||||
if e, ok := err.(*UnmarshalTypeError); ok {
|
||||
e.Struct = st.typ.String() + e.Struct
|
||||
e.Field = string(k) + "." + e.Field
|
||||
e.Field = d.prependField(string(k), e.Field)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
@ -1190,3 +1437,10 @@ func (d decoder) decodeTextUnmarshaler(b []byte, p unsafe.Pointer, t reflect.Typ
|
||||
|
||||
return b, &UnmarshalTypeError{Value: value, Type: reflect.PtrTo(t)}
|
||||
}
|
||||
|
||||
func (d decoder) prependField(key, field string) string {
|
||||
if field != "" {
|
||||
return key + "." + field
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
@ -476,6 +476,7 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
// encodeString doesn't return errors so we ignore it here
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&k))
|
||||
b = append(b, ':')
|
||||
|
||||
@ -534,6 +535,224 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
m := *(*map[string]string)(p)
|
||||
if m == nil {
|
||||
return append(b, "null"...), nil
|
||||
}
|
||||
|
||||
if (e.flags & SortMapKeys) == 0 {
|
||||
// Optimized code path when the program does not need the map keys to be
|
||||
// sorted.
|
||||
b = append(b, '{')
|
||||
|
||||
if len(m) != 0 {
|
||||
var i = 0
|
||||
|
||||
for k, v := range m {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
// encodeString never returns an error so we ignore it here
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&k))
|
||||
b = append(b, ':')
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&v))
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
s := mapslicePool.Get().(*mapslice)
|
||||
if cap(s.elements) < len(m) {
|
||||
s.elements = make([]element, 0, align(10, uintptr(len(m))))
|
||||
}
|
||||
for key, val := range m {
|
||||
v := val
|
||||
s.elements = append(s.elements, element{key: key, val: &v})
|
||||
}
|
||||
sort.Sort(s)
|
||||
|
||||
b = append(b, '{')
|
||||
|
||||
for i, elem := range s.elements {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
// encodeString never returns an error so we ignore it here
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
|
||||
b = append(b, ':')
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(elem.val.(*string)))
|
||||
}
|
||||
|
||||
for i := range s.elements {
|
||||
s.elements[i] = element{}
|
||||
}
|
||||
|
||||
s.elements = s.elements[:0]
|
||||
mapslicePool.Put(s)
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
m := *(*map[string][]string)(p)
|
||||
if m == nil {
|
||||
return append(b, "null"...), nil
|
||||
}
|
||||
|
||||
var stringSize = unsafe.Sizeof("")
|
||||
|
||||
if (e.flags & SortMapKeys) == 0 {
|
||||
// Optimized code path when the program does not need the map keys to be
|
||||
// sorted.
|
||||
b = append(b, '{')
|
||||
|
||||
if len(m) != 0 {
|
||||
var err error
|
||||
var i = 0
|
||||
|
||||
for k, v := range m {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&k))
|
||||
b = append(b, ':')
|
||||
|
||||
b, err = e.encodeSlice(b, unsafe.Pointer(&v), stringSize, sliceStringType, encoder.encodeString)
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
s := mapslicePool.Get().(*mapslice)
|
||||
if cap(s.elements) < len(m) {
|
||||
s.elements = make([]element, 0, align(10, uintptr(len(m))))
|
||||
}
|
||||
for key, val := range m {
|
||||
v := val
|
||||
s.elements = append(s.elements, element{key: key, val: &v})
|
||||
}
|
||||
sort.Sort(s)
|
||||
|
||||
var start = len(b)
|
||||
var err error
|
||||
b = append(b, '{')
|
||||
|
||||
for i, elem := range s.elements {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
|
||||
b = append(b, ':')
|
||||
|
||||
b, err = e.encodeSlice(b, unsafe.Pointer(elem.val.(*[]string)), stringSize, sliceStringType, encoder.encodeString)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for i := range s.elements {
|
||||
s.elements[i] = element{}
|
||||
}
|
||||
|
||||
s.elements = s.elements[:0]
|
||||
mapslicePool.Put(s)
|
||||
|
||||
if err != nil {
|
||||
return b[:start], err
|
||||
}
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (e encoder) encodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) {
|
||||
m := *(*map[string]bool)(p)
|
||||
if m == nil {
|
||||
return append(b, "null"...), nil
|
||||
}
|
||||
|
||||
if (e.flags & SortMapKeys) == 0 {
|
||||
// Optimized code path when the program does not need the map keys to be
|
||||
// sorted.
|
||||
b = append(b, '{')
|
||||
|
||||
if len(m) != 0 {
|
||||
var i = 0
|
||||
|
||||
for k, v := range m {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
// encodeString never returns an error so we ignore it here
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&k))
|
||||
if v {
|
||||
b = append(b, ":true"...)
|
||||
} else {
|
||||
b = append(b, ":false"...)
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
s := mapslicePool.Get().(*mapslice)
|
||||
if cap(s.elements) < len(m) {
|
||||
s.elements = make([]element, 0, align(10, uintptr(len(m))))
|
||||
}
|
||||
for key, val := range m {
|
||||
s.elements = append(s.elements, element{key: key, val: val})
|
||||
}
|
||||
sort.Sort(s)
|
||||
|
||||
b = append(b, '{')
|
||||
|
||||
for i, elem := range s.elements {
|
||||
if i != 0 {
|
||||
b = append(b, ',')
|
||||
}
|
||||
|
||||
// encodeString never returns an error so we ignore it here
|
||||
b, _ = e.encodeString(b, unsafe.Pointer(&elem.key))
|
||||
if elem.val.(bool) {
|
||||
b = append(b, ":true"...)
|
||||
} else {
|
||||
b = append(b, ":false"...)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range s.elements {
|
||||
s.elements[i] = element{}
|
||||
}
|
||||
|
||||
s.elements = s.elements[:0]
|
||||
mapslicePool.Put(s)
|
||||
|
||||
b = append(b, '}')
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (e encoder) encodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byte, error) {
|
||||
var start = len(b)
|
||||
var err error
|
||||
|
@ -1,58 +0,0 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppendInt(t *testing.T) {
|
||||
var ints []int64
|
||||
for i := 0; i < 64; i++ {
|
||||
u := uint64(1) << i
|
||||
ints = append(ints, int64(u-1), int64(u), int64(u+1), -int64(u))
|
||||
}
|
||||
|
||||
var std [20]byte
|
||||
var our [20]byte
|
||||
|
||||
for _, i := range ints {
|
||||
expected := strconv.AppendInt(std[:], i, 10)
|
||||
actual := appendInt(our[:], i)
|
||||
if string(expected) != string(actual) {
|
||||
t.Fatalf("appendInt(%d) = %v, expected = %v", i, string(actual), string(expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchStd(b *testing.B, n int64) {
|
||||
var buf [20]byte
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strconv.AppendInt(buf[:0], n, 10)
|
||||
}
|
||||
}
|
||||
|
||||
func benchNew(b *testing.B, n int64) {
|
||||
var buf [20]byte
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
appendInt(buf[:0], n)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendIntStd1(b *testing.B) {
|
||||
benchStd(b, 1)
|
||||
}
|
||||
|
||||
func BenchmarkAppendInt1(b *testing.B) {
|
||||
benchNew(b, 1)
|
||||
}
|
||||
|
||||
func BenchmarkAppendIntStdMinI64(b *testing.B) {
|
||||
benchStd(b, math.MinInt64)
|
||||
}
|
||||
|
||||
func BenchmarkAppendIntMinI64(b *testing.B) {
|
||||
benchNew(b, math.MinInt64)
|
||||
}
|
20
internal/gotiny/LICENSE
Normal file
20
internal/gotiny/LICENSE
Normal file
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 zheng-ji.info
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
205
internal/gotiny/decEngine.go
Normal file
205
internal/gotiny/decEngine.go
Normal file
@ -0,0 +1,205 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type decEng func(*Decoder, unsafe.Pointer) // 解码器
|
||||
|
||||
var (
|
||||
rt2decEng = map[reflect.Type]decEng{
|
||||
reflect.TypeOf((*bool)(nil)).Elem(): decBool,
|
||||
reflect.TypeOf((*int)(nil)).Elem(): decInt,
|
||||
reflect.TypeOf((*int8)(nil)).Elem(): decInt8,
|
||||
reflect.TypeOf((*int16)(nil)).Elem(): decInt16,
|
||||
reflect.TypeOf((*int32)(nil)).Elem(): decInt32,
|
||||
reflect.TypeOf((*int64)(nil)).Elem(): decInt64,
|
||||
reflect.TypeOf((*uint)(nil)).Elem(): decUint,
|
||||
reflect.TypeOf((*uint8)(nil)).Elem(): decUint8,
|
||||
reflect.TypeOf((*uint16)(nil)).Elem(): decUint16,
|
||||
reflect.TypeOf((*uint32)(nil)).Elem(): decUint32,
|
||||
reflect.TypeOf((*uint64)(nil)).Elem(): decUint64,
|
||||
reflect.TypeOf((*uintptr)(nil)).Elem(): decUintptr,
|
||||
reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): decPointer,
|
||||
reflect.TypeOf((*float32)(nil)).Elem(): decFloat32,
|
||||
reflect.TypeOf((*float64)(nil)).Elem(): decFloat64,
|
||||
reflect.TypeOf((*complex64)(nil)).Elem(): decComplex64,
|
||||
reflect.TypeOf((*complex128)(nil)).Elem(): decComplex128,
|
||||
reflect.TypeOf((*[]byte)(nil)).Elem(): decBytes,
|
||||
reflect.TypeOf((*string)(nil)).Elem(): decString,
|
||||
reflect.TypeOf((*time.Time)(nil)).Elem(): decTime,
|
||||
reflect.TypeOf((*struct{})(nil)).Elem(): decIgnore,
|
||||
reflect.TypeOf(nil): decIgnore,
|
||||
}
|
||||
|
||||
baseDecEngines = []decEng{
|
||||
reflect.Invalid: decIgnore,
|
||||
reflect.Bool: decBool,
|
||||
reflect.Int: decInt,
|
||||
reflect.Int8: decInt8,
|
||||
reflect.Int16: decInt16,
|
||||
reflect.Int32: decInt32,
|
||||
reflect.Int64: decInt64,
|
||||
reflect.Uint: decUint,
|
||||
reflect.Uint8: decUint8,
|
||||
reflect.Uint16: decUint16,
|
||||
reflect.Uint32: decUint32,
|
||||
reflect.Uint64: decUint64,
|
||||
reflect.Uintptr: decUintptr,
|
||||
reflect.UnsafePointer: decPointer,
|
||||
reflect.Float32: decFloat32,
|
||||
reflect.Float64: decFloat64,
|
||||
reflect.Complex64: decComplex64,
|
||||
reflect.Complex128: decComplex128,
|
||||
reflect.String: decString,
|
||||
}
|
||||
decLock sync.RWMutex
|
||||
)
|
||||
|
||||
func getDecEngine(rt reflect.Type) decEng {
|
||||
decLock.RLock()
|
||||
engine := rt2decEng[rt]
|
||||
decLock.RUnlock()
|
||||
if engine != nil {
|
||||
return engine
|
||||
}
|
||||
decLock.Lock()
|
||||
buildDecEngine(rt, &engine)
|
||||
decLock.Unlock()
|
||||
return engine
|
||||
}
|
||||
|
||||
func buildDecEngine(rt reflect.Type, engPtr *decEng) {
|
||||
engine, has := rt2decEng[rt]
|
||||
if has {
|
||||
*engPtr = engine
|
||||
return
|
||||
}
|
||||
|
||||
if _, engine = implementOtherSerializer(rt); engine != nil {
|
||||
rt2decEng[rt] = engine
|
||||
*engPtr = engine
|
||||
return
|
||||
}
|
||||
|
||||
kind := rt.Kind()
|
||||
var eEng decEng
|
||||
switch kind {
|
||||
case reflect.Ptr:
|
||||
et := rt.Elem()
|
||||
defer buildDecEngine(et, &eEng)
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
if d.decIsNotNil() {
|
||||
if isNil(p) {
|
||||
*(*unsafe.Pointer)(p) = unsafe.Pointer(reflect.New(et).Elem().UnsafeAddr())
|
||||
}
|
||||
eEng(d, *(*unsafe.Pointer)(p))
|
||||
} else if !isNil(p) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
}
|
||||
}
|
||||
case reflect.Array:
|
||||
l, et := rt.Len(), rt.Elem()
|
||||
size := et.Size()
|
||||
defer buildDecEngine(et, &eEng)
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
for i := 0; i < l; i++ {
|
||||
eEng(d, unsafe.Pointer(uintptr(p)+uintptr(i)*size))
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
et := rt.Elem()
|
||||
size := et.Size()
|
||||
defer buildDecEngine(et, &eEng)
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
header := (*reflect.SliceHeader)(p)
|
||||
if d.decIsNotNil() {
|
||||
l := d.decLength()
|
||||
if isNil(p) || header.Cap < l {
|
||||
*header = reflect.SliceHeader{Data: reflect.MakeSlice(rt, l, l).Pointer(), Len: l, Cap: l}
|
||||
} else {
|
||||
header.Len = l
|
||||
}
|
||||
for i := 0; i < l; i++ {
|
||||
eEng(d, unsafe.Pointer(header.Data+uintptr(i)*size))
|
||||
}
|
||||
} else if !isNil(p) {
|
||||
*header = reflect.SliceHeader{}
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
kt, vt := rt.Key(), rt.Elem()
|
||||
skt, svt := reflect.SliceOf(kt), reflect.SliceOf(vt)
|
||||
var kEng, vEng decEng
|
||||
defer buildDecEngine(kt, &kEng)
|
||||
defer buildDecEngine(vt, &vEng)
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
if d.decIsNotNil() {
|
||||
l := d.decLength()
|
||||
var v reflect.Value
|
||||
if isNil(p) {
|
||||
v = reflect.MakeMapWithSize(rt, l)
|
||||
*(*unsafe.Pointer)(p) = unsafe.Pointer(v.Pointer())
|
||||
} else {
|
||||
v = reflect.NewAt(rt, p).Elem()
|
||||
}
|
||||
keys, vals := reflect.MakeSlice(skt, l, l), reflect.MakeSlice(svt, l, l)
|
||||
for i := 0; i < l; i++ {
|
||||
key, val := keys.Index(i), vals.Index(i)
|
||||
kEng(d, unsafe.Pointer(key.UnsafeAddr()))
|
||||
vEng(d, unsafe.Pointer(val.UnsafeAddr()))
|
||||
v.SetMapIndex(key, val)
|
||||
}
|
||||
} else if !isNil(p) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
fields, offs := getFieldType(rt, 0)
|
||||
nf := len(fields)
|
||||
fEngines := make([]decEng, nf)
|
||||
defer func() {
|
||||
for i := 0; i < nf; i++ {
|
||||
buildDecEngine(fields[i], &fEngines[i])
|
||||
}
|
||||
}()
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
for i := 0; i < len(fEngines) && i < len(offs); i++ {
|
||||
fEngines[i](d, unsafe.Pointer(uintptr(p)+offs[i]))
|
||||
}
|
||||
}
|
||||
case reflect.Interface:
|
||||
engine = func(d *Decoder, p unsafe.Pointer) {
|
||||
if d.decIsNotNil() {
|
||||
name := ""
|
||||
decString(d, unsafe.Pointer(&name))
|
||||
et, has := name2type[name]
|
||||
if !has {
|
||||
//panic("unknown typ:" + name)
|
||||
fmt.Println("[session] Register this type first with the `RegisterType` method.")
|
||||
}
|
||||
v := reflect.NewAt(rt, p).Elem()
|
||||
var ev reflect.Value
|
||||
if v.IsNil() || v.Elem().Type() != et {
|
||||
ev = reflect.New(et).Elem()
|
||||
} else {
|
||||
ev = v.Elem()
|
||||
}
|
||||
getDecEngine(et)(d, getUnsafePointer(&ev))
|
||||
v.Set(ev)
|
||||
} else if !isNil(p) {
|
||||
*(*unsafe.Pointer)(p) = nil
|
||||
}
|
||||
}
|
||||
case reflect.Chan, reflect.Func:
|
||||
//panic("not support " + rt.String() + " type")
|
||||
default:
|
||||
engine = baseDecEngines[kind]
|
||||
}
|
||||
rt2decEng[rt] = engine
|
||||
*engPtr = engine
|
||||
}
|
161
internal/gotiny/decbase.go
Normal file
161
internal/gotiny/decbase.go
Normal file
@ -0,0 +1,161 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (d *Decoder) decBool() (b bool) {
|
||||
if d.boolBit == 0 {
|
||||
d.boolBit = 1
|
||||
d.boolPos = d.buf[d.index]
|
||||
d.index++
|
||||
}
|
||||
b = d.boolPos&d.boolBit != 0
|
||||
d.boolBit <<= 1
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Decoder) decUint64() uint64 {
|
||||
buf, i := d.buf, d.index
|
||||
x := uint64(buf[i])
|
||||
if x < 0x80 {
|
||||
d.index++
|
||||
return x
|
||||
}
|
||||
x1 := buf[i+1]
|
||||
x += uint64(x1) << 7
|
||||
if x1 < 0x80 {
|
||||
d.index += 2
|
||||
return x - 1<<7
|
||||
}
|
||||
x2 := buf[i+2]
|
||||
x += uint64(x2) << 14
|
||||
if x2 < 0x80 {
|
||||
d.index += 3
|
||||
return x - (1<<7 + 1<<14)
|
||||
}
|
||||
x3 := buf[i+3]
|
||||
x += uint64(x3) << 21
|
||||
if x3 < 0x80 {
|
||||
d.index += 4
|
||||
return x - (1<<7 + 1<<14 + 1<<21)
|
||||
}
|
||||
x4 := buf[i+4]
|
||||
x += uint64(x4) << 28
|
||||
if x4 < 0x80 {
|
||||
d.index += 5
|
||||
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28)
|
||||
}
|
||||
x5 := buf[i+5]
|
||||
x += uint64(x5) << 35
|
||||
if x5 < 0x80 {
|
||||
d.index += 6
|
||||
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35)
|
||||
}
|
||||
x6 := buf[i+6]
|
||||
x += uint64(x6) << 42
|
||||
if x6 < 0x80 {
|
||||
d.index += 7
|
||||
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42)
|
||||
}
|
||||
x7 := buf[i+7]
|
||||
x += uint64(x7) << 49
|
||||
if x7 < 0x80 {
|
||||
d.index += 8
|
||||
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49)
|
||||
}
|
||||
d.index += 9
|
||||
return x + uint64(buf[i+8])<<56 - (1<<7 + 1<<14 + 1<<21 + 1<<28 + 1<<35 + 1<<42 + 1<<49 + 1<<56)
|
||||
}
|
||||
|
||||
func (d *Decoder) decUint16() uint16 {
|
||||
buf, i := d.buf, d.index
|
||||
x := uint16(buf[i])
|
||||
if x < 0x80 {
|
||||
d.index++
|
||||
return x
|
||||
}
|
||||
x1 := buf[i+1]
|
||||
x += uint16(x1) << 7
|
||||
if x1 < 0x80 {
|
||||
d.index += 2
|
||||
return x - 1<<7
|
||||
}
|
||||
d.index += 3
|
||||
return x + uint16(buf[i+2])<<14 - (1<<7 + 1<<14)
|
||||
}
|
||||
|
||||
func (d *Decoder) decUint32() uint32 {
|
||||
buf, i := d.buf, d.index
|
||||
x := uint32(buf[i])
|
||||
if x < 0x80 {
|
||||
d.index++
|
||||
return x
|
||||
}
|
||||
x1 := buf[i+1]
|
||||
x += uint32(x1) << 7
|
||||
if x1 < 0x80 {
|
||||
d.index += 2
|
||||
return x - 1<<7
|
||||
}
|
||||
x2 := buf[i+2]
|
||||
x += uint32(x2) << 14
|
||||
if x2 < 0x80 {
|
||||
d.index += 3
|
||||
return x - (1<<7 + 1<<14)
|
||||
}
|
||||
x3 := buf[i+3]
|
||||
x += uint32(x3) << 21
|
||||
if x3 < 0x80 {
|
||||
d.index += 4
|
||||
return x - (1<<7 + 1<<14 + 1<<21)
|
||||
}
|
||||
x4 := buf[i+4]
|
||||
x += uint32(x4) << 28
|
||||
d.index += 5
|
||||
return x - (1<<7 + 1<<14 + 1<<21 + 1<<28)
|
||||
}
|
||||
|
||||
func (d *Decoder) decLength() int { return int(d.decUint32()) }
|
||||
func (d *Decoder) decIsNotNil() bool { return d.decBool() }
|
||||
|
||||
func decIgnore(*Decoder, unsafe.Pointer) {}
|
||||
func decBool(d *Decoder, p unsafe.Pointer) { *(*bool)(p) = d.decBool() }
|
||||
func decInt(d *Decoder, p unsafe.Pointer) { *(*int)(p) = int(uint64ToInt64(d.decUint64())) }
|
||||
func decInt8(d *Decoder, p unsafe.Pointer) { *(*int8)(p) = int8(d.buf[d.index]); d.index++ }
|
||||
func decInt16(d *Decoder, p unsafe.Pointer) { *(*int16)(p) = uint16ToInt16(d.decUint16()) }
|
||||
func decInt32(d *Decoder, p unsafe.Pointer) { *(*int32)(p) = uint32ToInt32(d.decUint32()) }
|
||||
func decInt64(d *Decoder, p unsafe.Pointer) { *(*int64)(p) = uint64ToInt64(d.decUint64()) }
|
||||
func decUint(d *Decoder, p unsafe.Pointer) { *(*uint)(p) = uint(d.decUint64()) }
|
||||
func decUint8(d *Decoder, p unsafe.Pointer) { *(*uint8)(p) = d.buf[d.index]; d.index++ }
|
||||
func decUint16(d *Decoder, p unsafe.Pointer) { *(*uint16)(p) = d.decUint16() }
|
||||
func decUint32(d *Decoder, p unsafe.Pointer) { *(*uint32)(p) = d.decUint32() }
|
||||
func decUint64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() }
|
||||
func decUintptr(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) }
|
||||
func decPointer(d *Decoder, p unsafe.Pointer) { *(*uintptr)(p) = uintptr(d.decUint64()) }
|
||||
func decFloat32(d *Decoder, p unsafe.Pointer) { *(*float32)(p) = uint32ToFloat32(d.decUint32()) }
|
||||
func decFloat64(d *Decoder, p unsafe.Pointer) { *(*float64)(p) = uint64ToFloat64(d.decUint64()) }
|
||||
func decTime(d *Decoder, p unsafe.Pointer) { *(*time.Time)(p) = time.Unix(0, int64(d.decUint64())) }
|
||||
func decComplex64(d *Decoder, p unsafe.Pointer) { *(*uint64)(p) = d.decUint64() }
|
||||
func decComplex128(d *Decoder, p unsafe.Pointer) {
|
||||
*(*uint64)(p) = d.decUint64()
|
||||
*(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size)) = d.decUint64()
|
||||
}
|
||||
|
||||
func decString(d *Decoder, p unsafe.Pointer) {
|
||||
l, val := int(d.decUint32()), (*string)(p)
|
||||
*val = string(d.buf[d.index : d.index+l])
|
||||
d.index += l
|
||||
}
|
||||
|
||||
func decBytes(d *Decoder, p unsafe.Pointer) {
|
||||
bytes := (*[]byte)(p)
|
||||
if d.decIsNotNil() {
|
||||
l := int(d.decUint32())
|
||||
*bytes = d.buf[d.index : d.index+l]
|
||||
d.index += l
|
||||
} else if !isNil(p) {
|
||||
*bytes = nil
|
||||
}
|
||||
}
|
97
internal/gotiny/decoder.go
Normal file
97
internal/gotiny/decoder.go
Normal file
@ -0,0 +1,97 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Decoder struct {
|
||||
buf []byte //buf
|
||||
index int //下一个要使用的字节在buf中的下标
|
||||
boolPos byte //下一次要读取的bool在buf中的下标,即buf[boolPos]
|
||||
boolBit byte //下一次要读取的bool的buf[boolPos]中的bit位
|
||||
|
||||
engines []decEng //解码器集合
|
||||
length int //解码器数量
|
||||
}
|
||||
|
||||
func Unmarshal(buf []byte, is ...interface{}) int {
|
||||
return NewDecoderWithPtr(is...).Decode(buf, is...)
|
||||
}
|
||||
|
||||
func NewDecoderWithPtr(is ...interface{}) *Decoder {
|
||||
l := len(is)
|
||||
engines := make([]decEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
rt := reflect.TypeOf(is[i])
|
||||
if rt.Kind() != reflect.Ptr {
|
||||
panic("must a pointer type!")
|
||||
}
|
||||
engines[i] = getDecEngine(rt.Elem())
|
||||
}
|
||||
return &Decoder{
|
||||
length: l,
|
||||
engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDecoder(is ...interface{}) *Decoder {
|
||||
l := len(is)
|
||||
engines := make([]decEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
engines[i] = getDecEngine(reflect.TypeOf(is[i]))
|
||||
}
|
||||
return &Decoder{
|
||||
length: l,
|
||||
engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDecoderWithType(ts ...reflect.Type) *Decoder {
|
||||
l := len(ts)
|
||||
des := make([]decEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
des[i] = getDecEngine(ts[i])
|
||||
}
|
||||
return &Decoder{
|
||||
length: l,
|
||||
engines: des,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) reset() int {
|
||||
index := d.index
|
||||
d.index = 0
|
||||
d.boolPos = 0
|
||||
d.boolBit = 0
|
||||
return index
|
||||
}
|
||||
|
||||
// is is pointer of variable
|
||||
func (d *Decoder) Decode(buf []byte, is ...interface{}) int {
|
||||
d.buf = buf
|
||||
engines := d.engines
|
||||
for i := 0; i < len(engines) && i < len(is); i++ {
|
||||
engines[i](d, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1])
|
||||
}
|
||||
return d.reset()
|
||||
}
|
||||
|
||||
// ps is a unsafe.Pointer of the variable
|
||||
func (d *Decoder) DecodePtr(buf []byte, ps ...unsafe.Pointer) int {
|
||||
d.buf = buf
|
||||
engines := d.engines
|
||||
for i := 0; i < len(engines) && i < len(ps); i++ {
|
||||
engines[i](d, ps[i])
|
||||
}
|
||||
return d.reset()
|
||||
}
|
||||
|
||||
func (d *Decoder) DecodeValue(buf []byte, vs ...reflect.Value) int {
|
||||
d.buf = buf
|
||||
engines := d.engines
|
||||
for i := 0; i < len(engines) && i < len(vs); i++ {
|
||||
engines[i](d, unsafe.Pointer(vs[i].UnsafeAddr()))
|
||||
}
|
||||
return d.reset()
|
||||
}
|
196
internal/gotiny/encEngine.go
Normal file
196
internal/gotiny/encEngine.go
Normal file
@ -0,0 +1,196 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type encEng func(*Encoder, unsafe.Pointer) //编码器
|
||||
|
||||
var (
|
||||
rt2encEng = map[reflect.Type]encEng{
|
||||
reflect.TypeOf((*bool)(nil)).Elem(): encBool,
|
||||
reflect.TypeOf((*int)(nil)).Elem(): encInt,
|
||||
reflect.TypeOf((*int8)(nil)).Elem(): encInt8,
|
||||
reflect.TypeOf((*int16)(nil)).Elem(): encInt16,
|
||||
reflect.TypeOf((*int32)(nil)).Elem(): encInt32,
|
||||
reflect.TypeOf((*int64)(nil)).Elem(): encInt64,
|
||||
reflect.TypeOf((*uint)(nil)).Elem(): encUint,
|
||||
reflect.TypeOf((*uint8)(nil)).Elem(): encUint8,
|
||||
reflect.TypeOf((*uint16)(nil)).Elem(): encUint16,
|
||||
reflect.TypeOf((*uint32)(nil)).Elem(): encUint32,
|
||||
reflect.TypeOf((*uint64)(nil)).Elem(): encUint64,
|
||||
reflect.TypeOf((*uintptr)(nil)).Elem(): encUintptr,
|
||||
reflect.TypeOf((*unsafe.Pointer)(nil)).Elem(): encPointer,
|
||||
reflect.TypeOf((*float32)(nil)).Elem(): encFloat32,
|
||||
reflect.TypeOf((*float64)(nil)).Elem(): encFloat64,
|
||||
reflect.TypeOf((*complex64)(nil)).Elem(): encComplex64,
|
||||
reflect.TypeOf((*complex128)(nil)).Elem(): encComplex128,
|
||||
reflect.TypeOf((*[]byte)(nil)).Elem(): encBytes,
|
||||
reflect.TypeOf((*string)(nil)).Elem(): encString,
|
||||
reflect.TypeOf((*time.Time)(nil)).Elem(): encTime,
|
||||
reflect.TypeOf((*struct{})(nil)).Elem(): encIgnore,
|
||||
reflect.TypeOf(nil): encIgnore,
|
||||
}
|
||||
|
||||
encEngines = [...]encEng{
|
||||
reflect.Invalid: encIgnore,
|
||||
reflect.Bool: encBool,
|
||||
reflect.Int: encInt,
|
||||
reflect.Int8: encInt8,
|
||||
reflect.Int16: encInt16,
|
||||
reflect.Int32: encInt32,
|
||||
reflect.Int64: encInt64,
|
||||
reflect.Uint: encUint,
|
||||
reflect.Uint8: encUint8,
|
||||
reflect.Uint16: encUint16,
|
||||
reflect.Uint32: encUint32,
|
||||
reflect.Uint64: encUint64,
|
||||
reflect.Uintptr: encUintptr,
|
||||
reflect.UnsafePointer: encPointer,
|
||||
reflect.Float32: encFloat32,
|
||||
reflect.Float64: encFloat64,
|
||||
reflect.Complex64: encComplex64,
|
||||
reflect.Complex128: encComplex128,
|
||||
reflect.String: encString,
|
||||
}
|
||||
|
||||
encLock sync.RWMutex
|
||||
)
|
||||
|
||||
func UnusedUnixNanoEncodeTimeType() {
|
||||
delete(rt2encEng, reflect.TypeOf((*time.Time)(nil)).Elem())
|
||||
delete(rt2decEng, reflect.TypeOf((*time.Time)(nil)).Elem())
|
||||
}
|
||||
|
||||
func getEncEngine(rt reflect.Type) encEng {
|
||||
encLock.RLock()
|
||||
engine := rt2encEng[rt]
|
||||
encLock.RUnlock()
|
||||
if engine != nil {
|
||||
return engine
|
||||
}
|
||||
encLock.Lock()
|
||||
buildEncEngine(rt, &engine)
|
||||
encLock.Unlock()
|
||||
return engine
|
||||
}
|
||||
|
||||
func buildEncEngine(rt reflect.Type, engPtr *encEng) {
|
||||
engine := rt2encEng[rt]
|
||||
if engine != nil {
|
||||
*engPtr = engine
|
||||
return
|
||||
}
|
||||
|
||||
if engine, _ = implementOtherSerializer(rt); engine != nil {
|
||||
rt2encEng[rt] = engine
|
||||
*engPtr = engine
|
||||
return
|
||||
}
|
||||
|
||||
kind := rt.Kind()
|
||||
var eEng encEng
|
||||
switch kind {
|
||||
case reflect.Ptr:
|
||||
defer buildEncEngine(rt.Elem(), &eEng)
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
eEng(e, *(*unsafe.Pointer)(p))
|
||||
}
|
||||
}
|
||||
case reflect.Array:
|
||||
et, l := rt.Elem(), rt.Len()
|
||||
defer buildEncEngine(et, &eEng)
|
||||
size := et.Size()
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
for i := 0; i < l; i++ {
|
||||
eEng(e, unsafe.Pointer(uintptr(p)+uintptr(i)*size))
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
et := rt.Elem()
|
||||
size := et.Size()
|
||||
defer buildEncEngine(et, &eEng)
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
header := (*reflect.SliceHeader)(p)
|
||||
l := header.Len
|
||||
e.encLength(l)
|
||||
for i := 0; i < l; i++ {
|
||||
eEng(e, unsafe.Pointer(header.Data+uintptr(i)*size))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
var kEng encEng
|
||||
defer buildEncEngine(rt.Key(), &kEng)
|
||||
defer buildEncEngine(rt.Elem(), &eEng)
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
v := reflect.NewAt(rt, p).Elem()
|
||||
e.encLength(v.Len())
|
||||
keys := v.MapKeys()
|
||||
for i := 0; i < len(keys); i++ {
|
||||
val := v.MapIndex(keys[i])
|
||||
kEng(e, getUnsafePointer(&keys[i]))
|
||||
eEng(e, getUnsafePointer(&val))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
fields, offs := getFieldType(rt, 0)
|
||||
nf := len(fields)
|
||||
fEngines := make([]encEng, nf)
|
||||
defer func() {
|
||||
for i := 0; i < nf; i++ {
|
||||
buildEncEngine(fields[i], &fEngines[i])
|
||||
}
|
||||
}()
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
for i := 0; i < len(fEngines) && i < len(offs); i++ {
|
||||
fEngines[i](e, unsafe.Pointer(uintptr(p)+offs[i]))
|
||||
}
|
||||
}
|
||||
case reflect.Interface:
|
||||
if rt.NumMethod() > 0 {
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
v := reflect.ValueOf(*(*interface {
|
||||
M()
|
||||
})(p))
|
||||
et := v.Type()
|
||||
e.encString(getNameOfType(et))
|
||||
getEncEngine(et)(e, getUnsafePointer(&v))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
engine = func(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
v := reflect.ValueOf(*(*interface{})(p))
|
||||
et := v.Type()
|
||||
e.encString(getNameOfType(et))
|
||||
getEncEngine(et)(e, getUnsafePointer(&v))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Chan, reflect.Func:
|
||||
//panic("not support " + rt.String() + " type")
|
||||
default:
|
||||
engine = encEngines[kind]
|
||||
}
|
||||
rt2encEng[rt] = engine
|
||||
*engPtr = engine
|
||||
}
|
108
internal/gotiny/encbase.go
Normal file
108
internal/gotiny/encbase.go
Normal file
@ -0,0 +1,108 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (e *Encoder) encBool(v bool) {
|
||||
if e.boolBit == 0 {
|
||||
e.boolPos = len(e.buf)
|
||||
e.buf = append(e.buf, 0)
|
||||
e.boolBit = 1
|
||||
}
|
||||
if v {
|
||||
e.buf[e.boolPos] |= e.boolBit
|
||||
}
|
||||
e.boolBit <<= 1
|
||||
}
|
||||
|
||||
func (e *Encoder) encUint64(v uint64) {
|
||||
switch {
|
||||
case v < 1<<7-1:
|
||||
e.buf = append(e.buf, byte(v))
|
||||
case v < 1<<14-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
|
||||
case v < 1<<21-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
|
||||
case v < 1<<28-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21))
|
||||
case v < 1<<35-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28))
|
||||
case v < 1<<42-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35))
|
||||
case v < 1<<49-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42))
|
||||
case v < 1<<56-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49))
|
||||
default:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28)|0x80, byte(v>>35)|0x80, byte(v>>42)|0x80, byte(v>>49)|0x80, byte(v>>56))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Encoder) encUint16(v uint16) {
|
||||
if v < 1<<7-1 {
|
||||
e.buf = append(e.buf, byte(v))
|
||||
} else if v < 1<<14-1 {
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
|
||||
} else {
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Encoder) encUint32(v uint32) {
|
||||
switch {
|
||||
case v < 1<<7-1:
|
||||
e.buf = append(e.buf, byte(v))
|
||||
case v < 1<<14-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7))
|
||||
case v < 1<<21-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14))
|
||||
case v < 1<<28-1:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21))
|
||||
default:
|
||||
e.buf = append(e.buf, byte(v)|0x80, byte(v>>7)|0x80, byte(v>>14)|0x80, byte(v>>21)|0x80, byte(v>>28))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Encoder) encLength(v int) { e.encUint32(uint32(v)) }
|
||||
func (e *Encoder) encString(s string) { e.encUint32(uint32(len(s))); e.buf = append(e.buf, s...) }
|
||||
func (e *Encoder) encIsNotNil(v bool) { e.encBool(v) }
|
||||
|
||||
func encIgnore(*Encoder, unsafe.Pointer) {}
|
||||
func encBool(e *Encoder, p unsafe.Pointer) { e.encBool(*(*bool)(p)) }
|
||||
func encInt(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(int64(*(*int)(p)))) }
|
||||
func encInt8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) }
|
||||
func encInt16(e *Encoder, p unsafe.Pointer) { e.encUint16(int16ToUint16(*(*int16)(p))) }
|
||||
func encInt32(e *Encoder, p unsafe.Pointer) { e.encUint32(int32ToUint32(*(*int32)(p))) }
|
||||
func encInt64(e *Encoder, p unsafe.Pointer) { e.encUint64(int64ToUint64(*(*int64)(p))) }
|
||||
func encUint8(e *Encoder, p unsafe.Pointer) { e.buf = append(e.buf, *(*uint8)(p)) }
|
||||
func encUint16(e *Encoder, p unsafe.Pointer) { e.encUint16(*(*uint16)(p)) }
|
||||
func encUint32(e *Encoder, p unsafe.Pointer) { e.encUint32(*(*uint32)(p)) }
|
||||
func encUint64(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint64)(p))) }
|
||||
func encUint(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uint)(p))) }
|
||||
func encUintptr(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) }
|
||||
func encPointer(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64(*(*uintptr)(p))) }
|
||||
func encFloat32(e *Encoder, p unsafe.Pointer) { e.encUint32(float32ToUint32(p)) }
|
||||
func encFloat64(e *Encoder, p unsafe.Pointer) { e.encUint64(float64ToUint64(p)) }
|
||||
func encString(e *Encoder, p unsafe.Pointer) {
|
||||
s := *(*string)(p)
|
||||
e.encUint32(uint32(len(s)))
|
||||
e.buf = append(e.buf, s...)
|
||||
}
|
||||
func encTime(e *Encoder, p unsafe.Pointer) { e.encUint64(uint64((*time.Time)(p).UnixNano())) }
|
||||
func encComplex64(e *Encoder, p unsafe.Pointer) { e.encUint64(*(*uint64)(p)) }
|
||||
func encComplex128(e *Encoder, p unsafe.Pointer) {
|
||||
e.encUint64(*(*uint64)(p))
|
||||
e.encUint64(*(*uint64)(unsafe.Pointer(uintptr(p) + ptr1Size)))
|
||||
}
|
||||
|
||||
func encBytes(e *Encoder, p unsafe.Pointer) {
|
||||
isNotNil := !isNil(p)
|
||||
e.encIsNotNil(isNotNil)
|
||||
if isNotNil {
|
||||
buf := *(*[]byte)(p)
|
||||
e.encLength(len(buf))
|
||||
e.buf = append(e.buf, buf...)
|
||||
}
|
||||
}
|
103
internal/gotiny/encoder.go
Normal file
103
internal/gotiny/encoder.go
Normal file
@ -0,0 +1,103 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Encoder struct {
|
||||
buf []byte //编码目的数组
|
||||
off int
|
||||
boolPos int //下一次要设置的bool在buf中的下标,即buf[boolPos]
|
||||
boolBit byte //下一次要设置的bool的buf[boolPos]中的bit位
|
||||
|
||||
engines []encEng
|
||||
length int
|
||||
}
|
||||
|
||||
func Marshal(is ...interface{}) []byte {
|
||||
return NewEncoderWithPtr(is...).Encode(is...)
|
||||
}
|
||||
|
||||
// 创建一个编码ps 指向类型的编码器
|
||||
func NewEncoderWithPtr(ps ...interface{}) *Encoder {
|
||||
l := len(ps)
|
||||
engines := make([]encEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
rt := reflect.TypeOf(ps[i])
|
||||
if rt.Kind() != reflect.Ptr {
|
||||
panic("must a pointer type!")
|
||||
}
|
||||
engines[i] = getEncEngine(rt.Elem())
|
||||
}
|
||||
return &Encoder{
|
||||
length: l,
|
||||
engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建一个编码is 类型的编码器
|
||||
func NewEncoder(is ...interface{}) *Encoder {
|
||||
l := len(is)
|
||||
engines := make([]encEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
engines[i] = getEncEngine(reflect.TypeOf(is[i]))
|
||||
}
|
||||
return &Encoder{
|
||||
length: l,
|
||||
engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
func NewEncoderWithType(ts ...reflect.Type) *Encoder {
|
||||
l := len(ts)
|
||||
engines := make([]encEng, l)
|
||||
for i := 0; i < l; i++ {
|
||||
engines[i] = getEncEngine(ts[i])
|
||||
}
|
||||
return &Encoder{
|
||||
length: l,
|
||||
engines: engines,
|
||||
}
|
||||
}
|
||||
|
||||
// 入参是要编码值的指针
|
||||
func (e *Encoder) Encode(is ...interface{}) []byte {
|
||||
engines := e.engines
|
||||
for i := 0; i < len(engines) && i < len(is); i++ {
|
||||
engines[i](e, (*[2]unsafe.Pointer)(unsafe.Pointer(&is[i]))[1])
|
||||
}
|
||||
return e.reset()
|
||||
}
|
||||
|
||||
// 入参是要编码的值得unsafe.Pointer 指针
|
||||
func (e *Encoder) EncodePtr(ps ...unsafe.Pointer) []byte {
|
||||
engines := e.engines
|
||||
for i := 0; i < len(engines) && i < len(ps); i++ {
|
||||
engines[i](e, ps[i])
|
||||
}
|
||||
return e.reset()
|
||||
}
|
||||
|
||||
// vs 是持有要编码的值
|
||||
func (e *Encoder) EncodeValue(vs ...reflect.Value) []byte {
|
||||
engines := e.engines
|
||||
for i := 0; i < len(engines) && i < len(vs); i++ {
|
||||
engines[i](e, getUnsafePointer(&vs[i]))
|
||||
}
|
||||
return e.reset()
|
||||
}
|
||||
|
||||
// 编码产生的数据将append到buf上
|
||||
func (e *Encoder) AppendTo(buf []byte) {
|
||||
e.off = len(buf)
|
||||
e.buf = buf
|
||||
}
|
||||
|
||||
func (e *Encoder) reset() []byte {
|
||||
buf := e.buf
|
||||
e.buf = buf[:e.off]
|
||||
e.boolBit = 0
|
||||
e.boolPos = 0
|
||||
return buf
|
||||
}
|
144
internal/gotiny/register.go
Normal file
144
internal/gotiny/register.go
Normal file
@ -0,0 +1,144 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
type2name = map[reflect.Type]string{}
|
||||
name2type = map[string]reflect.Type{}
|
||||
)
|
||||
|
||||
func GetName(obj interface{}) string {
|
||||
return GetNameByType(reflect.TypeOf(obj))
|
||||
}
|
||||
func GetNameByType(rt reflect.Type) string {
|
||||
return string(getName([]byte(nil), rt))
|
||||
}
|
||||
|
||||
func getName(prefix []byte, rt reflect.Type) []byte {
|
||||
if rt == nil || rt.Kind() == reflect.Invalid {
|
||||
return append(prefix, []byte("<nil>")...)
|
||||
}
|
||||
if rt.Name() == "" { //未命名的,组合类型
|
||||
switch rt.Kind() {
|
||||
case reflect.Ptr:
|
||||
return getName(append(prefix, '*'), rt.Elem())
|
||||
case reflect.Array:
|
||||
return getName(append(prefix, "["+strconv.Itoa(rt.Len())+"]"...), rt.Elem())
|
||||
case reflect.Slice:
|
||||
return getName(append(prefix, '[', ']'), rt.Elem())
|
||||
case reflect.Struct:
|
||||
prefix = append(prefix, "struct {"...)
|
||||
nf := rt.NumField()
|
||||
if nf > 0 {
|
||||
prefix = append(prefix, ' ')
|
||||
}
|
||||
for i := 0; i < nf; i++ {
|
||||
field := rt.Field(i)
|
||||
if field.Anonymous {
|
||||
prefix = getName(prefix, field.Type)
|
||||
} else {
|
||||
prefix = getName(append(prefix, field.Name+" "...), field.Type)
|
||||
}
|
||||
if i != nf-1 {
|
||||
prefix = append(prefix, ';', ' ')
|
||||
} else {
|
||||
prefix = append(prefix, ' ')
|
||||
}
|
||||
}
|
||||
return append(prefix, '}')
|
||||
case reflect.Map:
|
||||
return getName(append(getName(append(prefix, "map["...), rt.Key()), ']'), rt.Elem())
|
||||
case reflect.Interface:
|
||||
prefix = append(prefix, "interface {"...)
|
||||
nm := rt.NumMethod()
|
||||
if nm > 0 {
|
||||
prefix = append(prefix, ' ')
|
||||
}
|
||||
for i := 0; i < nm; i++ {
|
||||
method := rt.Method(i)
|
||||
fn := getName([]byte(nil), method.Type)
|
||||
prefix = append(prefix, method.Name+string(fn[4:])...)
|
||||
if i != nm-1 {
|
||||
prefix = append(prefix, ';', ' ')
|
||||
} else {
|
||||
prefix = append(prefix, ' ')
|
||||
}
|
||||
}
|
||||
return append(prefix, '}')
|
||||
case reflect.Func:
|
||||
prefix = append(prefix, "func("...)
|
||||
for i := 0; i < rt.NumIn(); i++ {
|
||||
prefix = getName(prefix, rt.In(i))
|
||||
if i != rt.NumIn()-1 {
|
||||
prefix = append(prefix, ',', ' ')
|
||||
}
|
||||
}
|
||||
prefix = append(prefix, ')')
|
||||
no := rt.NumOut()
|
||||
if no > 0 {
|
||||
prefix = append(prefix, ' ')
|
||||
}
|
||||
if no > 1 {
|
||||
prefix = append(prefix, '(')
|
||||
}
|
||||
for i := 0; i < no; i++ {
|
||||
prefix = getName(prefix, rt.Out(i))
|
||||
if i != no-1 {
|
||||
prefix = append(prefix, ',', ' ')
|
||||
}
|
||||
}
|
||||
if no > 1 {
|
||||
prefix = append(prefix, ')')
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
|
||||
if rt.PkgPath() == "" {
|
||||
prefix = append(prefix, rt.Name()...)
|
||||
} else {
|
||||
prefix = append(prefix, rt.PkgPath()+"."+rt.Name()...)
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func getNameOfType(rt reflect.Type) string {
|
||||
if name, has := type2name[rt]; has {
|
||||
return name
|
||||
} else {
|
||||
return registerType(rt)
|
||||
}
|
||||
}
|
||||
|
||||
func Register(i interface{}) string {
|
||||
return registerType(reflect.TypeOf(i))
|
||||
}
|
||||
|
||||
func registerType(rt reflect.Type) string {
|
||||
name := GetNameByType(rt)
|
||||
RegisterName(name, rt)
|
||||
return name
|
||||
}
|
||||
|
||||
func RegisterName(name string, rt reflect.Type) {
|
||||
if name == "" {
|
||||
panic("attempt to register empty name")
|
||||
}
|
||||
|
||||
if rt == nil || rt.Kind() == reflect.Invalid {
|
||||
panic("attempt to register nil type or invalid type")
|
||||
}
|
||||
|
||||
if _, has := type2name[rt]; has {
|
||||
panic("gotiny: registering duplicate types for " + GetNameByType(rt))
|
||||
}
|
||||
|
||||
if _, has := name2type[name]; has {
|
||||
panic("gotiny: registering name" + name + " is exist")
|
||||
}
|
||||
name2type[name] = rt
|
||||
type2name[rt] = name
|
||||
}
|
57
internal/gotiny/unsafe.go
Normal file
57
internal/gotiny/unsafe.go
Normal file
@ -0,0 +1,57 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
kindDirectIface = 1 << 5
|
||||
)
|
||||
|
||||
// rtype is the common implementation of most values.
|
||||
// It is embedded in other struct types.
|
||||
//
|
||||
// rtype must be kept in sync with reflect/type.go:/^type._type.
|
||||
type rtype struct {
|
||||
_ uintptr
|
||||
_ uintptr // number of bytes in the type that can contain pointers
|
||||
_ uint32 // hash of type; avoids computation in hash tables
|
||||
_ uint8 // extra type information flags
|
||||
_ uint8 // alignment of variable with this type
|
||||
_ uint8 // alignment of struct field with this type
|
||||
kind uint8 // enumeration for C
|
||||
_ uintptr // algorithm table
|
||||
_ uintptr // garbage collection data
|
||||
_ int32 // string form
|
||||
_ int32 // type for pointer to this type, may be zero
|
||||
}
|
||||
|
||||
// ifaceIndir reports whether t is stored indirectly in an interface value.
|
||||
func ifaceDirect(t *rtype) bool {
|
||||
return t.kind&kindDirectIface != 0
|
||||
}
|
||||
|
||||
func directType(rt *reflect.Type) bool {
|
||||
return ifaceDirect((*rtype)((*[2]unsafe.Pointer)(unsafe.Pointer(rt))[1]))
|
||||
}
|
||||
|
||||
type refVal struct {
|
||||
_ unsafe.Pointer
|
||||
ptr unsafe.Pointer
|
||||
flag flag
|
||||
}
|
||||
|
||||
type flag uintptr
|
||||
|
||||
//go:linkname flagIndir reflect.flagIndir
|
||||
const flagIndir flag = 1 << 7
|
||||
|
||||
func getUnsafePointer(rv *reflect.Value) unsafe.Pointer {
|
||||
vv := (*refVal)(unsafe.Pointer(rv))
|
||||
if vv.flag&flagIndir == 0 {
|
||||
return unsafe.Pointer(&vv.ptr)
|
||||
} else {
|
||||
return vv.ptr
|
||||
}
|
||||
}
|
185
internal/gotiny/utils.go
Normal file
185
internal/gotiny/utils.go
Normal file
@ -0,0 +1,185 @@
|
||||
package gotiny
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/gob"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
ptr1Size = 4 << (^uintptr(0) >> 63) // unsafe.Sizeof(uintptr(0)) but an ideal const
|
||||
)
|
||||
|
||||
func float64ToUint64(v unsafe.Pointer) uint64 {
|
||||
return reverse64Byte(*(*uint64)(v))
|
||||
}
|
||||
|
||||
func uint64ToFloat64(u uint64) float64 {
|
||||
u = reverse64Byte(u)
|
||||
return *((*float64)(unsafe.Pointer(&u)))
|
||||
}
|
||||
|
||||
func reverse64Byte(u uint64) uint64 {
|
||||
u = (u << 32) | (u >> 32)
|
||||
u = ((u << 16) & 0xFFFF0000FFFF0000) | ((u >> 16) & 0xFFFF0000FFFF)
|
||||
u = ((u << 8) & 0xFF00FF00FF00FF00) | ((u >> 8) & 0xFF00FF00FF00FF)
|
||||
return u
|
||||
}
|
||||
|
||||
func float32ToUint32(v unsafe.Pointer) uint32 {
|
||||
return reverse32Byte(*(*uint32)(v))
|
||||
}
|
||||
|
||||
func uint32ToFloat32(u uint32) float32 {
|
||||
u = reverse32Byte(u)
|
||||
return *((*float32)(unsafe.Pointer(&u)))
|
||||
}
|
||||
|
||||
func reverse32Byte(u uint32) uint32 {
|
||||
u = (u << 16) | (u >> 16)
|
||||
return ((u << 8) & 0xFF00FF00) | ((u >> 8) & 0xFF00FF)
|
||||
}
|
||||
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
func int64ToUint64(v int64) uint64 {
|
||||
return uint64((v << 1) ^ (v >> 63))
|
||||
}
|
||||
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
func uint64ToInt64(u uint64) int64 {
|
||||
v := int64(u)
|
||||
return (-(v & 1)) ^ (v>>1)&0x7FFFFFFFFFFFFFFF
|
||||
}
|
||||
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
func int32ToUint32(v int32) uint32 {
|
||||
return uint32((v << 1) ^ (v >> 31))
|
||||
}
|
||||
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
func uint32ToInt32(u uint32) int32 {
|
||||
v := int32(u)
|
||||
return (-(v & 1)) ^ (v>>1)&0x7FFFFFFF
|
||||
}
|
||||
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
func int16ToUint16(v int16) uint16 {
|
||||
return uint16((v << 1) ^ (v >> 15))
|
||||
}
|
||||
|
||||
// uint 9 7 5 3 1 0 2 4 6 8 10 12
|
||||
// int -5 -4 -3 -2 -1 0 1 2 3 4 5 6
|
||||
func uint16ToInt16(u uint16) int16 {
|
||||
v := int16(u)
|
||||
return (-(v & 1)) ^ (v>>1)&0x7FFF
|
||||
}
|
||||
|
||||
func isNil(p unsafe.Pointer) bool {
|
||||
return *(*unsafe.Pointer)(p) == nil
|
||||
}
|
||||
|
||||
type gobInter interface {
|
||||
gob.GobEncoder
|
||||
gob.GobDecoder
|
||||
}
|
||||
|
||||
type binInter interface {
|
||||
encoding.BinaryMarshaler
|
||||
encoding.BinaryUnmarshaler
|
||||
}
|
||||
|
||||
// 只应该由指针来实现该接口
|
||||
type GoTinySerializer interface {
|
||||
// 编码方法,将对象的序列化结果append到入参数并返回,方法不应该修改入参数值原有的值
|
||||
GotinyEncode([]byte) []byte
|
||||
// 解码方法,将入参解码到对象里并返回使用的长度。方法从入参的第0个字节开始使用,并且不应该修改入参中的任何数据
|
||||
GotinyDecode([]byte) int
|
||||
}
|
||||
|
||||
func implementOtherSerializer(rt reflect.Type) (encEng encEng, decEng decEng) {
|
||||
rtNil := reflect.Zero(reflect.PtrTo(rt)).Interface()
|
||||
if _, ok := rtNil.(GoTinySerializer); ok {
|
||||
encEng = func(e *Encoder, p unsafe.Pointer) {
|
||||
e.buf = reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyEncode(e.buf)
|
||||
}
|
||||
decEng = func(d *Decoder, p unsafe.Pointer) {
|
||||
d.index += reflect.NewAt(rt, p).Interface().(GoTinySerializer).GotinyDecode(d.buf[d.index:])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := rtNil.(binInter); ok {
|
||||
encEng = func(e *Encoder, p unsafe.Pointer) {
|
||||
buf, err := reflect.NewAt(rt, p).Interface().(encoding.BinaryMarshaler).MarshalBinary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
e.encLength(len(buf))
|
||||
e.buf = append(e.buf, buf...)
|
||||
}
|
||||
|
||||
decEng = func(d *Decoder, p unsafe.Pointer) {
|
||||
length := d.decLength()
|
||||
start := d.index
|
||||
d.index += length
|
||||
if err := reflect.NewAt(rt, p).Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(d.buf[start:d.index]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := rtNil.(gobInter); ok {
|
||||
encEng = func(e *Encoder, p unsafe.Pointer) {
|
||||
buf, err := reflect.NewAt(rt, p).Interface().(gob.GobEncoder).GobEncode()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
e.encLength(len(buf))
|
||||
e.buf = append(e.buf, buf...)
|
||||
}
|
||||
decEng = func(d *Decoder, p unsafe.Pointer) {
|
||||
length := d.decLength()
|
||||
start := d.index
|
||||
d.index += length
|
||||
if err := reflect.NewAt(rt, p).Interface().(gob.GobDecoder).GobDecode(d.buf[start:d.index]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// rt.kind is reflect.struct
|
||||
func getFieldType(rt reflect.Type, baseOff uintptr) (fields []reflect.Type, offs []uintptr) {
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
field := rt.Field(i)
|
||||
if ignoreField(field) {
|
||||
continue
|
||||
}
|
||||
ft := field.Type
|
||||
if ft.Kind() == reflect.Struct {
|
||||
if _, engine := implementOtherSerializer(ft); engine == nil {
|
||||
fFields, fOffs := getFieldType(ft, field.Offset+baseOff)
|
||||
fields = append(fields, fFields...)
|
||||
offs = append(offs, fOffs...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
fields = append(fields, ft)
|
||||
offs = append(offs, field.Offset+baseOff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ignoreField(field reflect.StructField) bool {
|
||||
tinyTag, ok := field.Tag.Lookup("gotiny")
|
||||
return ok && strings.TrimSpace(tinyTag) == "-"
|
||||
}
|
90
internal/memory/memory.go
Normal file
90
internal/memory/memory.go
Normal file
@ -0,0 +1,90 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
sync.RWMutex
|
||||
data map[string]item // data
|
||||
ts uint64 // timestamp
|
||||
}
|
||||
|
||||
type item struct {
|
||||
v interface{} // val
|
||||
e uint64 // exp
|
||||
}
|
||||
|
||||
func New() *Storage {
|
||||
store := &Storage{
|
||||
data: make(map[string]item),
|
||||
ts: uint64(time.Now().Unix()),
|
||||
}
|
||||
go store.gc(10 * time.Millisecond)
|
||||
go store.updater(1 * time.Second)
|
||||
return store
|
||||
}
|
||||
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) interface{} {
|
||||
s.RLock()
|
||||
v, ok := s.data[key]
|
||||
s.RUnlock()
|
||||
if !ok || v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
|
||||
return nil
|
||||
}
|
||||
return v.v
|
||||
}
|
||||
|
||||
// Set key with value
|
||||
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
|
||||
var exp uint64
|
||||
if ttl > 0 {
|
||||
exp = uint64(ttl.Seconds()) + atomic.LoadUint64(&s.ts)
|
||||
}
|
||||
s.Lock()
|
||||
s.data[key] = item{val, exp}
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Delete key by key
|
||||
func (s *Storage) Delete(key string) {
|
||||
s.Lock()
|
||||
delete(s.data, key)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Reset all keys
|
||||
func (s *Storage) Reset() {
|
||||
s.Lock()
|
||||
s.data = make(map[string]item)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
func (s *Storage) updater(sleep time.Duration) {
|
||||
for {
|
||||
time.Sleep(sleep)
|
||||
atomic.StoreUint64(&s.ts, uint64(time.Now().Unix()))
|
||||
}
|
||||
}
|
||||
func (s *Storage) gc(sleep time.Duration) {
|
||||
expired := []string{}
|
||||
for {
|
||||
time.Sleep(sleep)
|
||||
expired = expired[:0]
|
||||
s.RLock()
|
||||
for key, v := range s.data {
|
||||
if v.e != 0 && v.e <= atomic.LoadUint64(&s.ts) {
|
||||
expired = append(expired, key)
|
||||
}
|
||||
}
|
||||
s.RUnlock()
|
||||
s.Lock()
|
||||
for i := range expired {
|
||||
delete(s.data, expired[i])
|
||||
}
|
||||
s.Unlock()
|
||||
}
|
||||
}
|
81
internal/memory/memory_test.go
Normal file
81
internal/memory/memory_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_Memory -v -race
|
||||
|
||||
func Test_Memory(t *testing.T) {
|
||||
var store = New()
|
||||
var (
|
||||
key = "john"
|
||||
val interface{} = []byte("doe")
|
||||
exp = 1 * time.Second
|
||||
)
|
||||
|
||||
store.Set(key, val, 0)
|
||||
store.Set(key, val, 0)
|
||||
|
||||
result := store.Get(key)
|
||||
utils.AssertEqual(t, val, result)
|
||||
|
||||
result = store.Get("empty")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set(key, val, exp)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set(key, val, 0)
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, val, result)
|
||||
|
||||
store.Delete(key)
|
||||
result = store.Get(key)
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
store.Set("john", val, 0)
|
||||
store.Set("doe", val, 0)
|
||||
store.Reset()
|
||||
|
||||
result = store.Get("john")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
|
||||
result = store.Get("doe")
|
||||
utils.AssertEqual(t, nil, result)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
|
||||
func Benchmark_Memory(b *testing.B) {
|
||||
keyLength := 1000
|
||||
keys := make([]string, keyLength)
|
||||
for i := 0; i < keyLength; i++ {
|
||||
keys[i] = utils.UUID()
|
||||
}
|
||||
value := []string{"some", "random", "value"}
|
||||
|
||||
ttl := 2 * time.Second
|
||||
b.Run("fiber_memory", func(b *testing.B) {
|
||||
d := New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
for _, key := range keys {
|
||||
d.Set(key, value, ttl)
|
||||
}
|
||||
for _, key := range keys {
|
||||
_ = d.Get(key)
|
||||
}
|
||||
for _, key := range keys {
|
||||
d.Delete(key)
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
package memory
|
||||
|
||||
import "time"
|
||||
|
||||
// Config defines the config for storage.
|
||||
type Config struct {
|
||||
// Time before deleting expired keys
|
||||
//
|
||||
// Default is 10 * time.Second
|
||||
GCInterval time.Duration
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
GCInterval: 10 * time.Second,
|
||||
}
|
||||
|
||||
// configDefault is a helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.GCInterval.Seconds()) <= 0 {
|
||||
cfg.GCInterval = ConfigDefault.GCInterval
|
||||
}
|
||||
return cfg
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -14,23 +13,17 @@ type Storage struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Common storage errors
|
||||
var ErrNotExist = errors.New("key does not exist")
|
||||
|
||||
type entry struct {
|
||||
data []byte
|
||||
expiry int64
|
||||
}
|
||||
|
||||
// New creates a new memory storage
|
||||
func New(config ...Config) *Storage {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
func New() *Storage {
|
||||
// Create storage
|
||||
store := &Storage{
|
||||
db: make(map[string]entry),
|
||||
gcInterval: cfg.GCInterval,
|
||||
gcInterval: 10 * time.Second,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
@ -43,13 +36,13 @@ func New(config ...Config) *Storage {
|
||||
// Get value by key
|
||||
func (s *Storage) Get(key string) ([]byte, error) {
|
||||
if len(key) <= 0 {
|
||||
return nil, ErrNotExist
|
||||
return nil, nil
|
||||
}
|
||||
s.mux.RLock()
|
||||
v, ok := s.db[key]
|
||||
s.mux.RUnlock()
|
||||
if !ok || v.expiry != 0 && v.expiry <= time.Now().Unix() {
|
||||
return nil, ErrNotExist
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return v.data, nil
|
||||
|
9
internal/uuid/CONTRIBUTORS
Normal file
9
internal/uuid/CONTRIBUTORS
Normal file
@ -0,0 +1,9 @@
|
||||
Paul Borman <borman@google.com>
|
||||
bmatsuo
|
||||
shawnps
|
||||
theory
|
||||
jboverfelt
|
||||
dsymonds
|
||||
cd1
|
||||
wallclockbuilder
|
||||
dansouza
|
27
internal/uuid/LICENSE
Normal file
27
internal/uuid/LICENSE
Normal file
@ -0,0 +1,27 @@
|
||||
Copyright (c) 2009,2014 Google Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
80
internal/uuid/dce.go
Normal file
80
internal/uuid/dce.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// A Domain represents a Version 2 domain
|
||||
type Domain byte
|
||||
|
||||
// Domain constants for DCE Security (Version 2) UUIDs.
|
||||
const (
|
||||
Person = Domain(0)
|
||||
Group = Domain(1)
|
||||
Org = Domain(2)
|
||||
)
|
||||
|
||||
// NewDCESecurity returns a DCE Security (Version 2) UUID.
|
||||
//
|
||||
// The domain should be one of Person, Group or Org.
|
||||
// On a POSIX system the id should be the users UID for the Person
|
||||
// domain and the users GID for the Group. The meaning of id for
|
||||
// the domain Org or on non-POSIX systems is site defined.
|
||||
//
|
||||
// For a given domain/id pair the same token may be returned for up to
|
||||
// 7 minutes and 10 seconds.
|
||||
func NewDCESecurity(domain Domain, id uint32) (UUID, error) {
|
||||
uuid, err := NewUUID()
|
||||
if err == nil {
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
|
||||
uuid[9] = byte(domain)
|
||||
binary.BigEndian.PutUint32(uuid[0:], id)
|
||||
}
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
|
||||
// domain with the id returned by os.Getuid.
|
||||
//
|
||||
// NewDCESecurity(Person, uint32(os.Getuid()))
|
||||
func NewDCEPerson() (UUID, error) {
|
||||
return NewDCESecurity(Person, uint32(os.Getuid()))
|
||||
}
|
||||
|
||||
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
|
||||
// domain with the id returned by os.Getgid.
|
||||
//
|
||||
// NewDCESecurity(Group, uint32(os.Getgid()))
|
||||
func NewDCEGroup() (UUID, error) {
|
||||
return NewDCESecurity(Group, uint32(os.Getgid()))
|
||||
}
|
||||
|
||||
// Domain returns the domain for a Version 2 UUID. Domains are only defined
|
||||
// for Version 2 UUIDs.
|
||||
func (uuid UUID) Domain() Domain {
|
||||
return Domain(uuid[9])
|
||||
}
|
||||
|
||||
// ID returns the id for a Version 2 UUID. IDs are only defined for Version 2
|
||||
// UUIDs.
|
||||
func (uuid UUID) ID() uint32 {
|
||||
return binary.BigEndian.Uint32(uuid[0:4])
|
||||
}
|
||||
|
||||
func (d Domain) String() string {
|
||||
switch d {
|
||||
case Person:
|
||||
return "Person"
|
||||
case Group:
|
||||
return "Group"
|
||||
case Org:
|
||||
return "Org"
|
||||
}
|
||||
return fmt.Sprintf("Domain%d", int(d))
|
||||
}
|
12
internal/uuid/doc.go
Normal file
12
internal/uuid/doc.go
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package uuid generates and inspects UUIDs.
|
||||
//
|
||||
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security
|
||||
// Services.
|
||||
//
|
||||
// A UUID is a 16 byte (128 bit) array. UUIDs may be used as keys to
|
||||
// maps or compared directly.
|
||||
package uuid
|
53
internal/uuid/hash.go
Normal file
53
internal/uuid/hash.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// Well known namespace IDs and UUIDs
|
||||
var (
|
||||
NameSpaceDNS = Must(Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
|
||||
NameSpaceURL = Must(Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
|
||||
NameSpaceOID = Must(Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
|
||||
NameSpaceX500 = Must(Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
|
||||
Nil UUID // empty UUID, all zeros
|
||||
)
|
||||
|
||||
// NewHash returns a new UUID derived from the hash of space concatenated with
|
||||
// data generated by h. The hash should be at least 16 byte in length. The
|
||||
// first 16 bytes of the hash are used to form the UUID. The version of the
|
||||
// UUID will be the lower 4 bits of version. NewHash is used to implement
|
||||
// NewMD5 and NewSHA1.
|
||||
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
|
||||
h.Reset()
|
||||
h.Write(space[:])
|
||||
h.Write(data)
|
||||
s := h.Sum(nil)
|
||||
var uuid UUID
|
||||
copy(uuid[:], s)
|
||||
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
|
||||
return uuid
|
||||
}
|
||||
|
||||
// NewMD5 returns a new MD5 (Version 3) UUID based on the
|
||||
// supplied name space and data. It is the same as calling:
|
||||
//
|
||||
// NewHash(md5.New(), space, data, 3)
|
||||
func NewMD5(space UUID, data []byte) UUID {
|
||||
return NewHash(md5.New(), space, data, 3)
|
||||
}
|
||||
|
||||
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
|
||||
// supplied name space and data. It is the same as calling:
|
||||
//
|
||||
// NewHash(sha1.New(), space, data, 5)
|
||||
func NewSHA1(space UUID, data []byte) UUID {
|
||||
return NewHash(sha1.New(), space, data, 5)
|
||||
}
|
38
internal/uuid/marshal.go
Normal file
38
internal/uuid/marshal.go
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import "fmt"
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (uuid UUID) MarshalText() ([]byte, error) {
|
||||
var js [36]byte
|
||||
encodeHex(js[:], uuid)
|
||||
return js[:], nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (uuid *UUID) UnmarshalText(data []byte) error {
|
||||
id, err := ParseBytes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*uuid = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalBinary implements encoding.BinaryMarshaler.
|
||||
func (uuid UUID) MarshalBinary() ([]byte, error) {
|
||||
return uuid[:], nil
|
||||
}
|
||||
|
||||
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
|
||||
func (uuid *UUID) UnmarshalBinary(data []byte) error {
|
||||
if len(data) != 16 {
|
||||
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
|
||||
}
|
||||
copy(uuid[:], data)
|
||||
return nil
|
||||
}
|
90
internal/uuid/node.go
Normal file
90
internal/uuid/node.go
Normal file
@ -0,0 +1,90 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
nodeMu sync.Mutex
|
||||
ifname string // name of interface being used
|
||||
nodeID [6]byte // hardware for version 1 UUIDs
|
||||
zeroID [6]byte // nodeID with only 0's
|
||||
)
|
||||
|
||||
// NodeInterface returns the name of the interface from which the NodeID was
|
||||
// derived. The interface "user" is returned if the NodeID was set by
|
||||
// SetNodeID.
|
||||
func NodeInterface() string {
|
||||
defer nodeMu.Unlock()
|
||||
nodeMu.Lock()
|
||||
return ifname
|
||||
}
|
||||
|
||||
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
|
||||
// If name is "" then the first usable interface found will be used or a random
|
||||
// Node ID will be generated. If a named interface cannot be found then false
|
||||
// is returned.
|
||||
//
|
||||
// SetNodeInterface never fails when name is "".
|
||||
func SetNodeInterface(name string) bool {
|
||||
defer nodeMu.Unlock()
|
||||
nodeMu.Lock()
|
||||
return setNodeInterface(name)
|
||||
}
|
||||
|
||||
func setNodeInterface(name string) bool {
|
||||
iname, addr := getHardwareInterface(name) // null implementation for js
|
||||
if iname != "" && addr != nil {
|
||||
ifname = iname
|
||||
copy(nodeID[:], addr)
|
||||
return true
|
||||
}
|
||||
|
||||
// We found no interfaces with a valid hardware address. If name
|
||||
// does not specify a specific interface generate a random Node ID
|
||||
// (section 4.1.6)
|
||||
if name == "" {
|
||||
ifname = "random"
|
||||
randomBits(nodeID[:])
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
|
||||
// if not already set.
|
||||
func NodeID() []byte {
|
||||
defer nodeMu.Unlock()
|
||||
nodeMu.Lock()
|
||||
if nodeID == zeroID {
|
||||
setNodeInterface("")
|
||||
}
|
||||
nid := nodeID
|
||||
return nid[:]
|
||||
}
|
||||
|
||||
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
|
||||
// of id are used. If id is less than 6 bytes then false is returned and the
|
||||
// Node ID is not set.
|
||||
func SetNodeID(id []byte) bool {
|
||||
if len(id) < 6 {
|
||||
return false
|
||||
}
|
||||
defer nodeMu.Unlock()
|
||||
nodeMu.Lock()
|
||||
copy(nodeID[:], id)
|
||||
ifname = "user"
|
||||
return true
|
||||
}
|
||||
|
||||
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
|
||||
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
|
||||
func (uuid UUID) NodeID() []byte {
|
||||
var node [6]byte
|
||||
copy(node[:], uuid[10:])
|
||||
return node[:]
|
||||
}
|
12
internal/uuid/node_js.go
Normal file
12
internal/uuid/node_js.go
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright 2017 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build js
|
||||
|
||||
package uuid
|
||||
|
||||
// getHardwareInterface returns nil values for the JS version of the code.
|
||||
// This remvoves the "net" dependency, because it is not used in the browser.
|
||||
// Using the "net" library inflates the size of the transpiled JS code by 673k bytes.
|
||||
func getHardwareInterface(name string) (string, []byte) { return "", nil }
|
33
internal/uuid/node_net.go
Normal file
33
internal/uuid/node_net.go
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright 2017 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !js
|
||||
|
||||
package uuid
|
||||
|
||||
import "net"
|
||||
|
||||
var interfaces []net.Interface // cached list of interfaces
|
||||
|
||||
// getHardwareInterface returns the name and hardware address of interface name.
|
||||
// If name is "" then the name and hardware address of one of the system's
|
||||
// interfaces is returned. If no interfaces are found (name does not exist or
|
||||
// there are no interfaces) then "", nil is returned.
|
||||
//
|
||||
// Only addresses of at least 6 bytes are returned.
|
||||
func getHardwareInterface(name string) (string, []byte) {
|
||||
if interfaces == nil {
|
||||
var err error
|
||||
interfaces, err = net.Interfaces()
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
for _, ifs := range interfaces {
|
||||
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
|
||||
return ifs.Name, ifs.HardwareAddr
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
59
internal/uuid/sql.go
Normal file
59
internal/uuid/sql.go
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Scan implements sql.Scanner so UUIDs can be read from databases transparently
|
||||
// Currently, database types that map to string and []byte are supported. Please
|
||||
// consult database-specific driver documentation for matching types.
|
||||
func (uuid *UUID) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
|
||||
case string:
|
||||
// if an empty UUID comes from a table, we return a null UUID
|
||||
if src == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// see Parse for required string format
|
||||
u, err := Parse(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Scan: %v", err)
|
||||
}
|
||||
|
||||
*uuid = u
|
||||
|
||||
case []byte:
|
||||
// if an empty UUID comes from a table, we return a null UUID
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// assumes a simple slice of bytes if 16 bytes
|
||||
// otherwise attempts to parse
|
||||
if len(src) != 16 {
|
||||
return uuid.Scan(string(src))
|
||||
}
|
||||
copy((*uuid)[:], src)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Scan: unable to scan type %T into UUID", src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements sql.Valuer so that UUIDs can be written to databases
|
||||
// transparently. Currently, UUIDs map to strings. Please consult
|
||||
// database-specific driver documentation for matching types.
|
||||
func (uuid UUID) Value() (driver.Value, error) {
|
||||
return uuid.String(), nil
|
||||
}
|
123
internal/uuid/time.go
Normal file
123
internal/uuid/time.go
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
|
||||
// 1582.
|
||||
type Time int64
|
||||
|
||||
const (
|
||||
lillian = 2299160 // Julian day of 15 Oct 1582
|
||||
unix = 2440587 // Julian day of 1 Jan 1970
|
||||
epoch = unix - lillian // Days between epochs
|
||||
g1582 = epoch * 86400 // seconds between epochs
|
||||
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
|
||||
)
|
||||
|
||||
var (
|
||||
timeMu sync.Mutex
|
||||
lasttime uint64 // last time we returned
|
||||
clockSeq uint16 // clock sequence for this run
|
||||
|
||||
timeNow = time.Now // for testing
|
||||
)
|
||||
|
||||
// UnixTime converts t the number of seconds and nanoseconds using the Unix
|
||||
// epoch of 1 Jan 1970.
|
||||
func (t Time) UnixTime() (sec, nsec int64) {
|
||||
sec = int64(t - g1582ns100)
|
||||
nsec = (sec % 10000000) * 100
|
||||
sec /= 10000000
|
||||
return sec, nsec
|
||||
}
|
||||
|
||||
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
|
||||
// clock sequence as well as adjusting the clock sequence as needed. An error
|
||||
// is returned if the current time cannot be determined.
|
||||
func GetTime() (Time, uint16, error) {
|
||||
defer timeMu.Unlock()
|
||||
timeMu.Lock()
|
||||
return getTime()
|
||||
}
|
||||
|
||||
func getTime() (Time, uint16, error) {
|
||||
t := timeNow()
|
||||
|
||||
// If we don't have a clock sequence already, set one.
|
||||
if clockSeq == 0 {
|
||||
setClockSequence(-1)
|
||||
}
|
||||
now := uint64(t.UnixNano()/100) + g1582ns100
|
||||
|
||||
// If time has gone backwards with this clock sequence then we
|
||||
// increment the clock sequence
|
||||
if now <= lasttime {
|
||||
clockSeq = ((clockSeq + 1) & 0x3fff) | 0x8000
|
||||
}
|
||||
lasttime = now
|
||||
return Time(now), clockSeq, nil
|
||||
}
|
||||
|
||||
// ClockSequence returns the current clock sequence, generating one if not
|
||||
// already set. The clock sequence is only used for Version 1 UUIDs.
|
||||
//
|
||||
// The uuid package does not use global static storage for the clock sequence or
|
||||
// the last time a UUID was generated. Unless SetClockSequence is used, a new
|
||||
// random clock sequence is generated the first time a clock sequence is
|
||||
// requested by ClockSequence, GetTime, or NewUUID. (section 4.2.1.1)
|
||||
func ClockSequence() int {
|
||||
defer timeMu.Unlock()
|
||||
timeMu.Lock()
|
||||
return clockSequence()
|
||||
}
|
||||
|
||||
func clockSequence() int {
|
||||
if clockSeq == 0 {
|
||||
setClockSequence(-1)
|
||||
}
|
||||
return int(clockSeq & 0x3fff)
|
||||
}
|
||||
|
||||
// SetClockSequence sets the clock sequence to the lower 14 bits of seq. Setting to
|
||||
// -1 causes a new sequence to be generated.
|
||||
func SetClockSequence(seq int) {
|
||||
defer timeMu.Unlock()
|
||||
timeMu.Lock()
|
||||
setClockSequence(seq)
|
||||
}
|
||||
|
||||
func setClockSequence(seq int) {
|
||||
if seq == -1 {
|
||||
var b [2]byte
|
||||
randomBits(b[:]) // clock sequence
|
||||
seq = int(b[0])<<8 | int(b[1])
|
||||
}
|
||||
oldSeq := clockSeq
|
||||
clockSeq = uint16(seq&0x3fff) | 0x8000 // Set our variant
|
||||
if oldSeq != clockSeq {
|
||||
lasttime = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
|
||||
// uuid. The time is only defined for version 1 and 2 UUIDs.
|
||||
func (uuid UUID) Time() Time {
|
||||
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
|
||||
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
|
||||
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
|
||||
return Time(time)
|
||||
}
|
||||
|
||||
// ClockSequence returns the clock sequence encoded in uuid.
|
||||
// The clock sequence is only well defined for version 1 and 2 UUIDs.
|
||||
func (uuid UUID) ClockSequence() int {
|
||||
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff
|
||||
}
|
43
internal/uuid/util.go
Normal file
43
internal/uuid/util.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// randomBits completely fills slice b with random data.
|
||||
func randomBits(b []byte) {
|
||||
if _, err := io.ReadFull(rander, b); err != nil {
|
||||
panic(err.Error()) // rand should never fail
|
||||
}
|
||||
}
|
||||
|
||||
// xvalues returns the value of a byte as a hexadecimal digit or 255.
|
||||
var xvalues = [256]byte{
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
|
||||
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
}
|
||||
|
||||
// xtob converts hex characters x1 and x2 into a byte.
|
||||
func xtob(x1, x2 byte) (byte, bool) {
|
||||
b1 := xvalues[x1]
|
||||
b2 := xvalues[x2]
|
||||
return (b1 << 4) | b2, b1 != 255 && b2 != 255
|
||||
}
|
245
internal/uuid/uuid.go
Normal file
245
internal/uuid/uuid.go
Normal file
@ -0,0 +1,245 @@
|
||||
// Copyright 2018 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
|
||||
// 4122.
|
||||
type UUID [16]byte
|
||||
|
||||
// A Version represents a UUID's version.
|
||||
type Version byte
|
||||
|
||||
// A Variant represents a UUID's variant.
|
||||
type Variant byte
|
||||
|
||||
// Constants returned by Variant.
|
||||
const (
|
||||
Invalid = Variant(iota) // Invalid UUID
|
||||
RFC4122 // The variant specified in RFC4122
|
||||
Reserved // Reserved, NCS backward compatibility.
|
||||
Microsoft // Reserved, Microsoft Corporation backward compatibility.
|
||||
Future // Reserved for future definition.
|
||||
)
|
||||
|
||||
var rander = rand.Reader // random function
|
||||
|
||||
// Parse decodes s into a UUID or returns an error. Both the standard UUID
|
||||
// forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
|
||||
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
|
||||
// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex
|
||||
// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.
|
||||
func Parse(s string) (UUID, error) {
|
||||
var uuid UUID
|
||||
switch len(s) {
|
||||
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
case 36:
|
||||
|
||||
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
case 36 + 9:
|
||||
if strings.ToLower(s[:9]) != "urn:uuid:" {
|
||||
return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9])
|
||||
}
|
||||
s = s[9:]
|
||||
|
||||
// {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
|
||||
case 36 + 2:
|
||||
s = s[1:]
|
||||
|
||||
// xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
case 32:
|
||||
var ok bool
|
||||
for i := range uuid {
|
||||
uuid[i], ok = xtob(s[i*2], s[i*2+1])
|
||||
if !ok {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
}
|
||||
return uuid, nil
|
||||
default:
|
||||
return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
|
||||
}
|
||||
// s is now at least 36 bytes long
|
||||
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
for i, x := range [16]int{
|
||||
0, 2, 4, 6,
|
||||
9, 11,
|
||||
14, 16,
|
||||
19, 21,
|
||||
24, 26, 28, 30, 32, 34} {
|
||||
v, ok := xtob(s[x], s[x+1])
|
||||
if !ok {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
uuid[i] = v
|
||||
}
|
||||
return uuid, nil
|
||||
}
|
||||
|
||||
// ParseBytes is like Parse, except it parses a byte slice instead of a string.
|
||||
func ParseBytes(b []byte) (UUID, error) {
|
||||
var uuid UUID
|
||||
switch len(b) {
|
||||
case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) {
|
||||
return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
|
||||
}
|
||||
b = b[9:]
|
||||
case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
|
||||
b = b[1:]
|
||||
case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
var ok bool
|
||||
for i := 0; i < 32; i += 2 {
|
||||
uuid[i/2], ok = xtob(b[i], b[i+1])
|
||||
if !ok {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
}
|
||||
return uuid, nil
|
||||
default:
|
||||
return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
|
||||
}
|
||||
// s is now at least 36 bytes long
|
||||
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
for i, x := range [16]int{
|
||||
0, 2, 4, 6,
|
||||
9, 11,
|
||||
14, 16,
|
||||
19, 21,
|
||||
24, 26, 28, 30, 32, 34} {
|
||||
v, ok := xtob(b[x], b[x+1])
|
||||
if !ok {
|
||||
return uuid, errors.New("invalid UUID format")
|
||||
}
|
||||
uuid[i] = v
|
||||
}
|
||||
return uuid, nil
|
||||
}
|
||||
|
||||
// MustParse is like Parse but panics if the string cannot be parsed.
|
||||
// It simplifies safe initialization of global variables holding compiled UUIDs.
|
||||
func MustParse(s string) UUID {
|
||||
uuid, err := Parse(s)
|
||||
if err != nil {
|
||||
panic(`uuid: Parse(` + s + `): ` + err.Error())
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
// FromBytes creates a new UUID from a byte slice. Returns an error if the slice
|
||||
// does not have a length of 16. The bytes are copied from the slice.
|
||||
func FromBytes(b []byte) (uuid UUID, err error) {
|
||||
err = uuid.UnmarshalBinary(b)
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
// Must returns uuid if err is nil and panics otherwise.
|
||||
func Must(uuid UUID, err error) UUID {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
// , or "" if uuid is invalid.
|
||||
func (uuid UUID) String() string {
|
||||
var buf [36]byte
|
||||
encodeHex(buf[:], uuid)
|
||||
return string(buf[:])
|
||||
}
|
||||
|
||||
// URN returns the RFC 2141 URN form of uuid,
|
||||
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
|
||||
func (uuid UUID) URN() string {
|
||||
var buf [36 + 9]byte
|
||||
copy(buf[:], "urn:uuid:")
|
||||
encodeHex(buf[9:], uuid)
|
||||
return string(buf[:])
|
||||
}
|
||||
|
||||
func encodeHex(dst []byte, uuid UUID) {
|
||||
hex.Encode(dst, uuid[:4])
|
||||
dst[8] = '-'
|
||||
hex.Encode(dst[9:13], uuid[4:6])
|
||||
dst[13] = '-'
|
||||
hex.Encode(dst[14:18], uuid[6:8])
|
||||
dst[18] = '-'
|
||||
hex.Encode(dst[19:23], uuid[8:10])
|
||||
dst[23] = '-'
|
||||
hex.Encode(dst[24:], uuid[10:])
|
||||
}
|
||||
|
||||
// Variant returns the variant encoded in uuid.
|
||||
func (uuid UUID) Variant() Variant {
|
||||
switch {
|
||||
case (uuid[8] & 0xc0) == 0x80:
|
||||
return RFC4122
|
||||
case (uuid[8] & 0xe0) == 0xc0:
|
||||
return Microsoft
|
||||
case (uuid[8] & 0xe0) == 0xe0:
|
||||
return Future
|
||||
default:
|
||||
return Reserved
|
||||
}
|
||||
}
|
||||
|
||||
// Version returns the version of uuid.
|
||||
func (uuid UUID) Version() Version {
|
||||
return Version(uuid[6] >> 4)
|
||||
}
|
||||
|
||||
func (v Version) String() string {
|
||||
if v > 15 {
|
||||
return fmt.Sprintf("BAD_VERSION_%d", v)
|
||||
}
|
||||
return fmt.Sprintf("VERSION_%d", v)
|
||||
}
|
||||
|
||||
func (v Variant) String() string {
|
||||
switch v {
|
||||
case RFC4122:
|
||||
return "RFC4122"
|
||||
case Reserved:
|
||||
return "Reserved"
|
||||
case Microsoft:
|
||||
return "Microsoft"
|
||||
case Future:
|
||||
return "Future"
|
||||
case Invalid:
|
||||
return "Invalid"
|
||||
}
|
||||
return fmt.Sprintf("BadVariant%d", int(v))
|
||||
}
|
||||
|
||||
// SetRand sets the random number generator to r, which implements io.Reader.
|
||||
// If r.Read returns an error when the package requests random data then
|
||||
// a panic will be issued.
|
||||
//
|
||||
// Calling SetRand with nil sets the random number generator to the default
|
||||
// generator.
|
||||
func SetRand(r io.Reader) {
|
||||
if r == nil {
|
||||
rander = rand.Reader
|
||||
return
|
||||
}
|
||||
rander = r
|
||||
}
|
44
internal/uuid/version1.go
Normal file
44
internal/uuid/version1.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
|
||||
// sequence, and the current time. If the NodeID has not been set by SetNodeID
|
||||
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
|
||||
// be set NewUUID returns nil. If clock sequence has not been set by
|
||||
// SetClockSequence then it will be set automatically. If GetTime fails to
|
||||
// return the current NewUUID returns nil and an error.
|
||||
//
|
||||
// In most cases, New should be used.
|
||||
func NewUUID() (UUID, error) {
|
||||
var uuid UUID
|
||||
now, seq, err := GetTime()
|
||||
if err != nil {
|
||||
return uuid, err
|
||||
}
|
||||
|
||||
timeLow := uint32(now & 0xffffffff)
|
||||
timeMid := uint16((now >> 32) & 0xffff)
|
||||
timeHi := uint16((now >> 48) & 0x0fff)
|
||||
timeHi |= 0x1000 // Version 1
|
||||
|
||||
binary.BigEndian.PutUint32(uuid[0:], timeLow)
|
||||
binary.BigEndian.PutUint16(uuid[4:], timeMid)
|
||||
binary.BigEndian.PutUint16(uuid[6:], timeHi)
|
||||
binary.BigEndian.PutUint16(uuid[8:], seq)
|
||||
|
||||
nodeMu.Lock()
|
||||
if nodeID == zeroID {
|
||||
setNodeInterface("")
|
||||
}
|
||||
copy(uuid[10:], nodeID[:])
|
||||
nodeMu.Unlock()
|
||||
|
||||
return uuid, nil
|
||||
}
|
43
internal/uuid/version4.go
Normal file
43
internal/uuid/version4.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package uuid
|
||||
|
||||
import "io"
|
||||
|
||||
// New creates a new random UUID or panics. New is equivalent to
|
||||
// the expression
|
||||
//
|
||||
// uuid.Must(uuid.NewRandom())
|
||||
func New() UUID {
|
||||
return Must(NewRandom())
|
||||
}
|
||||
|
||||
// NewRandom returns a Random (Version 4) UUID.
|
||||
//
|
||||
// The strength of the UUIDs is based on the strength of the crypto/rand
|
||||
// package.
|
||||
//
|
||||
// A note about uniqueness derived from the UUID Wikipedia entry:
|
||||
//
|
||||
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
|
||||
// hit by a meteorite is estimated to be one chance in 17 billion, that
|
||||
// means the probability is about 0.00000000006 (6 × 10−11),
|
||||
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
|
||||
// year and having one duplicate.
|
||||
func NewRandom() (UUID, error) {
|
||||
return NewRandomFromReader(rander)
|
||||
}
|
||||
|
||||
// NewRandomFromReader returns a UUID based on bytes read from a given io.Reader.
|
||||
func NewRandomFromReader(r io.Reader) (UUID, error) {
|
||||
var uuid UUID
|
||||
_, err := io.ReadFull(r, uuid[:])
|
||||
if err != nil {
|
||||
return Nil, err
|
||||
}
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
|
||||
return uuid, nil
|
||||
}
|
7
middleware/cache/README.md
vendored
7
middleware/cache/README.md
vendored
@ -75,12 +75,12 @@ type Config struct {
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Store fiber.Storage
|
||||
Storage fiber.Storage
|
||||
}
|
||||
```
|
||||
|
||||
@ -92,8 +92,9 @@ var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
Storage: nil,
|
||||
}
|
||||
```
|
||||
|
149
middleware/cache/cache.go
vendored
149
middleware/cache/cache.go
vendored
@ -17,15 +17,21 @@ func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Nothing to cache
|
||||
if int(cfg.Expiration.Seconds()) < 0 {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// Cache settings
|
||||
mux = &sync.RWMutex{}
|
||||
timestamp = uint64(time.Now().Unix())
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
mux = &sync.RWMutex{}
|
||||
|
||||
// Default store logic (if no Store is provided)
|
||||
entries = make(map[string]entry)
|
||||
)
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
go func() {
|
||||
@ -35,30 +41,6 @@ func New(config ...Config) fiber.Handler {
|
||||
}
|
||||
}()
|
||||
|
||||
// Nothing to cache
|
||||
if int(cfg.Expiration.Seconds()) < 0 {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Remove expired entries
|
||||
if cfg.defaultStore {
|
||||
go func() {
|
||||
for {
|
||||
// GC the entries every 10 seconds
|
||||
time.Sleep(10 * time.Second)
|
||||
mux.Lock()
|
||||
for k := range entries {
|
||||
if atomic.LoadUint64(×tamp) >= entries[k].exp {
|
||||
delete(entries, k)
|
||||
}
|
||||
}
|
||||
mux.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
@ -72,74 +54,45 @@ func New(config ...Config) fiber.Handler {
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.Key(c)
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Create new entry
|
||||
var entry entry
|
||||
var entryBody []byte
|
||||
// Get entry from pool
|
||||
e := manager.get(key)
|
||||
|
||||
// Lock entry
|
||||
// Lock entry and unlock when finished
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Check if we need to use the default in-memory storage
|
||||
if cfg.defaultStore {
|
||||
entry = entries[key]
|
||||
|
||||
} else {
|
||||
// Load data from store
|
||||
storeEntry, err := cfg.Storage.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only decode if we found an entry
|
||||
if storeEntry != nil {
|
||||
// Decode bytes using msgp
|
||||
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get timestamp
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if entry.exp == 0 {
|
||||
entry.exp = ts + expiration
|
||||
if e.exp == 0 {
|
||||
// Set expiration if entry does not exist
|
||||
e.exp = ts + expiration
|
||||
|
||||
} else if ts >= entry.exp {
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
delete(entries, key)
|
||||
} else { // Use custom storage
|
||||
if err := cfg.Storage.Delete(key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cfg.Storage.Delete(key + "_body"); err != nil {
|
||||
return err
|
||||
}
|
||||
manager.delete(key)
|
||||
// External storage saves body data with different key
|
||||
if cfg.Storage != nil {
|
||||
manager.delete(key + "_body")
|
||||
}
|
||||
|
||||
} else {
|
||||
if cfg.defaultStore {
|
||||
c.Response().SetBodyRaw(entry.body)
|
||||
} else {
|
||||
c.Response().SetBodyRaw(entryBody)
|
||||
// Seperate body value to avoid msgp serialization
|
||||
// We can store raw bytes with Storage 👍
|
||||
if cfg.Storage != nil {
|
||||
e.body = manager.getRaw(key + "_body")
|
||||
}
|
||||
// Set response headers from cache
|
||||
c.Response().SetStatusCode(entry.status)
|
||||
c.Response().Header.SetContentTypeBytes(entry.cType)
|
||||
|
||||
c.Response().SetBodyRaw(e.body)
|
||||
c.Response().SetStatusCode(e.status)
|
||||
c.Response().Header.SetContentTypeBytes(e.ctype)
|
||||
if len(e.cencoding) > 0 {
|
||||
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
|
||||
}
|
||||
// Set Cache-Control header if enabled
|
||||
if cfg.CacheControl {
|
||||
maxAge := strconv.FormatUint(entry.exp-ts, 10)
|
||||
maxAge := strconv.FormatUint(e.exp-ts, 10)
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
||||
}
|
||||
|
||||
@ -153,31 +106,21 @@ func New(config ...Config) fiber.Handler {
|
||||
}
|
||||
|
||||
// Cache response
|
||||
entryBody = utils.SafeBytes(c.Response().Body())
|
||||
entry.status = c.Response().StatusCode()
|
||||
entry.cType = utils.SafeBytes(c.Response().Header.ContentType())
|
||||
|
||||
// Use default memory storage
|
||||
if cfg.defaultStore {
|
||||
entry.body = entryBody
|
||||
entries[key] = entry
|
||||
e.body = utils.SafeBytes(c.Response().Body())
|
||||
e.status = c.Response().StatusCode()
|
||||
e.ctype = utils.SafeBytes(c.Response().Header.ContentType())
|
||||
e.cencoding = utils.SafeBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
|
||||
|
||||
// For external Storage we store raw body seperated
|
||||
if cfg.Storage != nil {
|
||||
manager.setRaw(key+"_body", e.body, cfg.Expiration)
|
||||
// avoid body msgp encoding
|
||||
e.body = nil
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
manager.release(e)
|
||||
} else {
|
||||
// Use custom storage
|
||||
data, err := entry.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass bytes to Storage
|
||||
if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
// Store entry in memory
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
}
|
||||
|
||||
// Finish response
|
||||
|
125
middleware/cache/cache_test.go
vendored
125
middleware/cache/cache_test.go
vendored
@ -6,13 +6,12 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -93,54 +92,54 @@ func Test_Cache(t *testing.T) {
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
|
||||
// go test -run Test_Cache_Concurrency_Store -race -v
|
||||
func Test_Cache_Concurrency_Store(t *testing.T) {
|
||||
// Test concurrency using a custom store
|
||||
// // go test -run Test_Cache_Concurrency_Storage -race -v
|
||||
// func Test_Cache_Concurrency_Storage(t *testing.T) {
|
||||
// // Test concurrency using a custom store
|
||||
|
||||
app := fiber.New()
|
||||
// app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
|
||||
}))
|
||||
// app.Use(New(Config{
|
||||
// Storage: memory.New(),
|
||||
// }))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
// app.Get("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendString("Hello tester!")
|
||||
// })
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
// var wg sync.WaitGroup
|
||||
// singleRequest := func(wg *sync.WaitGroup) {
|
||||
// defer wg.Done()
|
||||
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
// }
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
// for i := 0; i <= 49; i++ {
|
||||
// wg.Add(1)
|
||||
// go singleRequest(&wg)
|
||||
// }
|
||||
|
||||
wg.Wait()
|
||||
// wg.Wait()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// req := httptest.NewRequest("GET", "/", nil)
|
||||
// resp, err := app.Test(req)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// cachedReq := httptest.NewRequest("GET", "/", nil)
|
||||
// cachedResp, err := app.Test(cachedReq)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// body, err := ioutil.ReadAll(resp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// cachedBody, err := ioutil.ReadAll(cachedResp.Body)
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
// utils.AssertEqual(t, cachedBody, body)
|
||||
// }
|
||||
|
||||
func Test_Cache_Invalid_Expiration(t *testing.T) {
|
||||
app := fiber.New()
|
||||
@ -235,7 +234,7 @@ func Test_Cache_NothingToCache(t *testing.T) {
|
||||
func Test_CustomKey(t *testing.T) {
|
||||
app := fiber.New()
|
||||
var called bool
|
||||
app.Use(New(Config{Key: func(c *fiber.Ctx) string {
|
||||
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
|
||||
called = true
|
||||
return c.Path()
|
||||
}}))
|
||||
@ -279,12 +278,12 @@ func Benchmark_Cache(b *testing.B) {
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache_Store -benchmem -count=4
|
||||
func Benchmark_Cache_Store(b *testing.B) {
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
|
||||
func Benchmark_Cache_Storage(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Store: testStore{stmap: map[string][]byte{}, mutex: &sync.RWMutex{}},
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
@ -308,43 +307,3 @@ func Benchmark_Cache_Store(b *testing.B) {
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
// testStore is used for testing custom stores
|
||||
type testStore struct {
|
||||
stmap map[string][]byte
|
||||
mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func (s testStore) Get(id string) ([]byte, error) {
|
||||
s.mutex.RLock()
|
||||
val, ok := s.stmap[id]
|
||||
s.mutex.RUnlock()
|
||||
if !ok {
|
||||
return nil, nil
|
||||
} else {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s testStore) Set(id string, val []byte, _ time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
s.stmap[id] = val
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Reset() error {
|
||||
s.stmap = map[string][]byte{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Delete(id string) error {
|
||||
s.mutex.Lock()
|
||||
delete(s.stmap, id)
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
37
middleware/cache/config.go
vendored
37
middleware/cache/config.go
vendored
@ -29,19 +29,18 @@ type Config struct {
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.Path()
|
||||
// }
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// Deprecated, use Storage instead
|
||||
Store fiber.Storage
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Internally used - if true, the simpler method of two maps is used in order to keep
|
||||
// execution time down.
|
||||
defaultStore bool
|
||||
// Deprecated, use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated, use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
@ -49,10 +48,10 @@ var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheControl: false,
|
||||
Key: func(c *fiber.Ctx) string {
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.Path()
|
||||
},
|
||||
defaultStore: true,
|
||||
Storage: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
@ -66,22 +65,22 @@ func configDefault(config ...Config) Config {
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("[CACHE] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
fmt.Println("[CACHE] Key is deprecated, please use KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.Key == nil {
|
||||
cfg.Key = ConfigDefault.Key
|
||||
}
|
||||
if cfg.Storage == nil && cfg.Store == nil {
|
||||
cfg.defaultStore = true
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
fmt.Println("cache: `Store` is deprecated, use `Storage` instead")
|
||||
cfg.Storage = cfg.Store
|
||||
cfg.defaultStore = true
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
122
middleware/cache/manager.go
vendored
Normal file
122
middleware/cache/manager.go
vendored
Normal file
@ -0,0 +1,122 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
body []byte
|
||||
ctype []byte
|
||||
cencoding []byte
|
||||
status int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
e.body = nil
|
||||
e.ctype = nil
|
||||
e.status = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
@ -30,10 +30,16 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
case "cType":
|
||||
z.cType, err = dc.ReadBytes(z.cType)
|
||||
case "ctype":
|
||||
z.ctype, err = dc.ReadBytes(z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
case "cencoding":
|
||||
z.cencoding, err = dc.ReadBytes(z.cencoding)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cencoding")
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
@ -60,10 +66,10 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 4
|
||||
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 5
|
||||
// write "body"
|
||||
err = en.Append(0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
err = en.Append(0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -72,14 +78,24 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
// write "cType"
|
||||
err = en.Append(0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||
// write "ctype"
|
||||
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteBytes(z.cType)
|
||||
err = en.WriteBytes(z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
// write "cencoding"
|
||||
err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteBytes(z.cencoding)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cencoding")
|
||||
return
|
||||
}
|
||||
// write "status"
|
||||
@ -106,15 +122,18 @@ func (z *entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 4
|
||||
// map header, size 5
|
||||
// string "body"
|
||||
o = append(o, 0x84, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = append(o, 0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = msgp.AppendBytes(o, z.body)
|
||||
// string "cType"
|
||||
o = append(o, 0xa5, 0x63, 0x54, 0x79, 0x70, 0x65)
|
||||
o = msgp.AppendBytes(o, z.cType)
|
||||
// string "ctype"
|
||||
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
o = msgp.AppendBytes(o, z.ctype)
|
||||
// string "cencoding"
|
||||
o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
|
||||
o = msgp.AppendBytes(o, z.cencoding)
|
||||
// string "status"
|
||||
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||
o = msgp.AppendInt(o, z.status)
|
||||
@ -125,7 +144,7 @@ func (z *entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
@ -148,10 +167,16 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
err = msgp.WrapError(err, "body")
|
||||
return
|
||||
}
|
||||
case "cType":
|
||||
z.cType, bts, err = msgp.ReadBytesBytes(bts, z.cType)
|
||||
case "ctype":
|
||||
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cType")
|
||||
err = msgp.WrapError(err, "ctype")
|
||||
return
|
||||
}
|
||||
case "cencoding":
|
||||
z.cencoding, bts, err = msgp.ReadBytesBytes(bts, z.cencoding)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "cencoding")
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
@ -179,7 +204,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *entry) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.cType) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
func (z *item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
12
middleware/cache/store.go
vendored
12
middleware/cache/store.go
vendored
@ -1,12 +0,0 @@
|
||||
package cache
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type entry struct {
|
||||
body []byte `msg:"body"`
|
||||
cType []byte `msg:"cType"`
|
||||
status int `msg:"status"`
|
||||
exp uint64 `msg:"exp"`
|
||||
}
|
@ -23,9 +23,9 @@ _NOTE: This middleware uses our [Storage](https://github.com/gofiber/storage) pa
|
||||
func New(config ...Config) fiber.Handler
|
||||
```
|
||||
|
||||
## Examples
|
||||
### Examples
|
||||
|
||||
First import the middleware from Fiber,
|
||||
Import the middleware package that is part of the Fiber web framework
|
||||
|
||||
```go
|
||||
import (
|
||||
@ -34,11 +34,7 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
Then create a Fiber app with `app := fiber.New()`.
|
||||
|
||||
### Default Config
|
||||
|
||||
Then apply the middleware to your Fiber app,
|
||||
After you initiate your Fiber app, you can use the following possibilities:
|
||||
|
||||
```go
|
||||
app.Use(csrf.New()) // Default config
|
||||
@ -90,7 +86,7 @@ type Config struct {
|
||||
KeyLookup string
|
||||
|
||||
// Name of the session cookie. This cookie will store session key.
|
||||
// Optional. Default value "_csrf".
|
||||
// Optional. Default value "csrf_".
|
||||
CookieName string
|
||||
|
||||
// Domain of the CSRF cookie.
|
||||
@ -109,7 +105,7 @@ type Config struct {
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
// Indicates if CSRF cookie is requested by SameSite.
|
||||
// Optional. Default value "Strict".
|
||||
CookieSameSite string
|
||||
|
||||
|
@ -28,7 +28,7 @@ type Config struct {
|
||||
KeyLookup string
|
||||
|
||||
// Name of the session cookie. This cookie will store session key.
|
||||
// Optional. Default value "_csrf".
|
||||
// Optional. Default value "csrf_".
|
||||
CookieName string
|
||||
|
||||
// Domain of the CSRF cookie.
|
||||
|
@ -2,13 +2,11 @@ package csrf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
@ -16,10 +14,8 @@ func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Set default values
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = memory.New()
|
||||
}
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Generate the correct extractor to get the token from the correct location
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
@ -39,14 +35,10 @@ func New(config ...Config) fiber.Handler {
|
||||
case "param":
|
||||
extractor = csrfFromParam(selectors[1])
|
||||
case "cookie":
|
||||
if selectors[1] == cfg.CookieName {
|
||||
panic(fmt.Sprintf("KeyLookup key %s can't be the same as CookieName %s", selectors[1], cfg.CookieName))
|
||||
}
|
||||
extractor = csrfFromCookie(selectors[1])
|
||||
}
|
||||
|
||||
// We only use Keys in Storage, so we need a dummy value
|
||||
dummyVal := []byte{'+'}
|
||||
dummyValue := []byte{'+'}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) {
|
||||
@ -69,9 +61,7 @@ func New(config ...Config) fiber.Handler {
|
||||
token = cfg.KeyGenerator()
|
||||
|
||||
// Add token to Storage
|
||||
if err = cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil {
|
||||
fmt.Println("[CSRF]", err.Error())
|
||||
}
|
||||
manager.setRaw(token, dummyValue, cfg.Expiration)
|
||||
}
|
||||
|
||||
// Create cookie to pass token to client
|
||||
@ -85,22 +75,19 @@ func New(config ...Config) fiber.Handler {
|
||||
HTTPOnly: cfg.CookieHTTPOnly,
|
||||
SameSite: cfg.CookieSameSite,
|
||||
}
|
||||
|
||||
// Set cookie to response
|
||||
c.Cookie(cookie)
|
||||
|
||||
case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut:
|
||||
// Verify CSRF token
|
||||
// Extract token from client request i.e. header, query, param, form or cookie
|
||||
token, err = extractor(c)
|
||||
if err != nil {
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
// We have a problem extracting the csrf token from Storage
|
||||
if _, err = cfg.Storage.Get(token); err != nil {
|
||||
// The token is invalid, let client generate a new one
|
||||
if err = cfg.Storage.Delete(token); err != nil {
|
||||
fmt.Println("[CSRF]", err.Error())
|
||||
}
|
||||
|
||||
// 403 if token does not exist in Storage
|
||||
if manager.getRaw(token) == nil {
|
||||
|
||||
// Expire cookie
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: cfg.CookieName,
|
||||
@ -111,8 +98,13 @@ func New(config ...Config) fiber.Handler {
|
||||
HTTPOnly: cfg.CookieHTTPOnly,
|
||||
SameSite: cfg.CookieSameSite,
|
||||
})
|
||||
|
||||
// Return 403 Forbidden
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
// The token is validated, time to delete it
|
||||
manager.delete(token)
|
||||
}
|
||||
|
||||
// Protect clients from caching the response by telling the browser
|
||||
|
112
middleware/csrf/manager.go
Normal file
112
middleware/csrf/manager.go
Normal file
@ -0,0 +1,112 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
90
middleware/csrf/manager_msgp.go
Normal file
90
middleware/csrf/manager_msgp.go
Normal file
@ -0,0 +1,90 @@
|
||||
package csrf
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 0
|
||||
err = en.Append(0x80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 0
|
||||
o = append(o, 0x80)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1
|
||||
return
|
||||
}
|
@ -59,6 +59,11 @@ app.Get("/", func(c *fiber.Ctx) error {
|
||||
```go
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
@ -68,11 +73,6 @@ type Config struct {
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
```
|
||||
|
||||
@ -80,7 +80,7 @@ type Config struct {
|
||||
|
||||
```go
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
Weak: false,
|
||||
}
|
||||
```
|
||||
|
44
middleware/etag/config.go
Normal file
44
middleware/etag/config.go
Normal file
@ -0,0 +1,44 @@
|
||||
package etag
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
// to generate efficiently. Weak ETag values of two representations
|
||||
// of the same resources might be semantically equivalent, but not
|
||||
// byte-for-byte identical. This means weak etags prevent caching
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
return cfg
|
||||
}
|
@ -8,42 +8,13 @@ import (
|
||||
"github.com/gofiber/fiber/v2/internal/bytebufferpool"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
// to generate efficiently. Weak ETag values of two representations
|
||||
// of the same resources might be semantically equivalent, but not
|
||||
// byte-for-byte identical. This means weak etags prevent caching
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
var normalizedHeaderETag = []byte("Etag")
|
||||
var weakPrefix = []byte("W/")
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var crc32q = crc32.MakeTable(0xD5828281)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -26,16 +25,16 @@ func New(config ...Config) fiber.Handler {
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var (
|
||||
// Limiter settings
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
timestamp = uint64(time.Now().Unix())
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
mux = &sync.RWMutex{}
|
||||
|
||||
// Default store logic (if no Store is provided)
|
||||
entries = make(map[string]entry)
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
go func() {
|
||||
for {
|
||||
@ -54,65 +53,39 @@ func New(config ...Config) fiber.Handler {
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Create new entry
|
||||
entry := entry{}
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
val, err := cfg.Storage.Get(key)
|
||||
if val != nil && len(val) > 0 {
|
||||
if _, err := entry.UnmarshalMsg(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err != nil && err.Error() != errNotExist {
|
||||
fmt.Println("[LIMITER]", err.Error())
|
||||
}
|
||||
} else {
|
||||
entry = entries[key]
|
||||
}
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if entry.exp == 0 {
|
||||
entry.exp = ts + expiration
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
|
||||
} else if ts >= entry.exp {
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
entry.hits = 0
|
||||
entry.exp = ts + expiration
|
||||
e.hits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
entry.hits++
|
||||
|
||||
// Use Storage if provided
|
||||
if cfg.Storage != nil {
|
||||
// Marshal entry to bytes
|
||||
val, err := entry.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass value to Storage
|
||||
if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
entries[key] = entry
|
||||
}
|
||||
e.hits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
expire := entry.exp - ts
|
||||
expire := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - entry.hits
|
||||
remaining := cfg.Max - e.hits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -172,7 +171,6 @@ func Test_Limiter_Headers(t *testing.T) {
|
||||
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
|
||||
}
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
|
||||
fmt.Println(v)
|
||||
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
|
||||
}
|
||||
}
|
||||
|
115
middleware/limiter/manager.go
Normal file
115
middleware/limiter/manager.go
Normal file
@ -0,0 +1,115 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type item struct {
|
||||
hits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item)
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.hits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) (it *item) {
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
if raw, _ := m.storage.Get(key); raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil {
|
||||
it = m.acquire()
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) (raw []byte) {
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key)
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) delete(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key)
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
@ -48,7 +48,7 @@ func (z *entry) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// write "hits"
|
||||
err = en.Append(0x82, 0xa4, 0x68, 0x69, 0x74, 0x73)
|
||||
@ -74,7 +74,7 @@ func (z entry) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// string "hits"
|
||||
@ -87,7 +87,7 @@ func (z entry) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
@ -129,7 +129,7 @@ func (z *entry) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z entry) Msgsize() (s int) {
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
package limiter
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type entry struct {
|
||||
hits int `msg:"hits"`
|
||||
exp uint64 `msg:"exp"`
|
||||
}
|
@ -129,7 +129,7 @@ func New(config ...Config) fiber.Handler {
|
||||
|
||||
// Set error handler once
|
||||
once.Do(func() {
|
||||
errHandler = c.App().Config().ErrorHandler
|
||||
// get longested possible path
|
||||
stack := c.App().Stack()
|
||||
for m := range stack {
|
||||
for r := range stack[m] {
|
||||
@ -139,7 +139,8 @@ func New(config ...Config) fiber.Handler {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// override error handler
|
||||
errHandler = c.App().Config().ErrorHandler
|
||||
})
|
||||
|
||||
// Set latency start time
|
||||
|
@ -14,13 +14,25 @@ _NOTE: This middleware uses our [Storage](https://github.com/gofiber/storage) pa
|
||||
## Signatures
|
||||
|
||||
```go
|
||||
func New(config ...Config) fiber.Handler
|
||||
func New(config ...Config) *Store
|
||||
func (s *Store) RegisterType(i interface{})
|
||||
func (s *Store) Get(c *fiber.Ctx) (*Session, error)
|
||||
func (s *Store) Reset() error
|
||||
|
||||
func (s *Session) Get(key string) interface{}
|
||||
func (s *Session) Set(key string, val interface{})
|
||||
func (s *Session) Delete(key string)
|
||||
func (s *Session) Destroy() error
|
||||
func (s *Session) Regenerate() error
|
||||
func (s *Session) Save() error
|
||||
func (s *Session) Fresh() bool
|
||||
func (s *Session) ID() string
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
First import the middleware from Fiber,
|
||||
**⚠ _Storing `interface{}` values are limited to built-ins Go types_**
|
||||
|
||||
### Examples
|
||||
Import the middleware package that is part of the Fiber web framework
|
||||
```go
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@ -45,9 +57,6 @@ app.Get("/", func(c *fiber.Ctx) error {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// save session
|
||||
defer sess.Save()
|
||||
|
||||
// Get value
|
||||
name := sess.Get("name")
|
||||
|
||||
@ -62,6 +71,11 @@ app.Get("/", func(c *fiber.Ctx) error {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// save session
|
||||
if err := sess.Save(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return fmt.Fprintf(ctx, "Welcome %v", name)
|
||||
})
|
||||
```
|
||||
|
@ -42,7 +42,7 @@ type Config struct {
|
||||
CookieSameSite string
|
||||
|
||||
// KeyGenerator generates the session key.
|
||||
// Optional. Default value utils.UUID
|
||||
// Optional. Default value utils.UUIDv4
|
||||
KeyGenerator func() string
|
||||
}
|
||||
|
||||
@ -50,7 +50,7 @@ type Config struct {
|
||||
var ConfigDefault = Config{
|
||||
Expiration: 24 * time.Hour,
|
||||
CookieName: "session_id",
|
||||
KeyGenerator: utils.UUID,
|
||||
KeyGenerator: utils.UUIDv4,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
@ -67,6 +67,9 @@ func configDefault(config ...Config) Config {
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.CookieName == "" {
|
||||
cfg.CookieName = ConfigDefault.CookieName
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
|
62
middleware/session/data.go
Normal file
62
middleware/session/data.go
Normal file
@ -0,0 +1,62 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type data struct {
|
||||
sync.RWMutex `gotiny:"-"`
|
||||
d map[string]interface{} `gotiny:"d"`
|
||||
}
|
||||
|
||||
var dataPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
d := new(data)
|
||||
d.d = make(map[string]interface{})
|
||||
return d
|
||||
},
|
||||
}
|
||||
|
||||
func acquireData() *data {
|
||||
return dataPool.Get().(*data)
|
||||
}
|
||||
|
||||
func releaseData(d *data) {
|
||||
d.Reset()
|
||||
dataPool.Put(d)
|
||||
}
|
||||
|
||||
func (d *data) Reset() {
|
||||
d.Lock()
|
||||
for key := range d.d {
|
||||
delete(d.d, key)
|
||||
}
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Get(key string) interface{} {
|
||||
d.RLock()
|
||||
v := d.d[key]
|
||||
d.RUnlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (d *data) Set(key string, value interface{}) {
|
||||
d.Lock()
|
||||
d.d[key] = value
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Delete(key string) {
|
||||
d.Lock()
|
||||
delete(d.d, key)
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Len() int {
|
||||
return len(d.d)
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
package session
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported
|
||||
// don't forget to replace the msgp import path to:
|
||||
// "github.com/gofiber/fiber/v2/internal/msgp"
|
||||
type db struct {
|
||||
d []kv
|
||||
}
|
||||
|
||||
// go:generate msgp
|
||||
type kv struct {
|
||||
k string
|
||||
v interface{}
|
||||
}
|
||||
|
||||
func (d *db) Reset() {
|
||||
d.d = d.d[:0]
|
||||
}
|
||||
|
||||
func (d *db) Get(key string) interface{} {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
return d.d[idx].v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *db) Set(key string, value interface{}) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
kv := &d.d[idx]
|
||||
kv.v = value
|
||||
} else {
|
||||
d.append(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *db) Delete(key string) {
|
||||
idx := d.indexOf(key)
|
||||
if idx > -1 {
|
||||
n := len(d.d) - 1
|
||||
d.swap(idx, n)
|
||||
d.d = d.d[:n]
|
||||
}
|
||||
}
|
||||
|
||||
func (d *db) Len() int {
|
||||
return len(d.d)
|
||||
}
|
||||
|
||||
func (d *db) swap(i, j int) {
|
||||
iKey, iValue := d.d[i].k, d.d[i].v
|
||||
jKey, jValue := d.d[j].k, d.d[j].v
|
||||
|
||||
d.d[i].k, d.d[i].v = jKey, jValue
|
||||
d.d[j].k, d.d[j].v = iKey, iValue
|
||||
}
|
||||
|
||||
func (d *db) allocPage() *kv {
|
||||
n := len(d.d)
|
||||
if cap(d.d) > n {
|
||||
d.d = d.d[:n+1]
|
||||
} else {
|
||||
d.d = append(d.d, kv{})
|
||||
}
|
||||
return &d.d[n]
|
||||
}
|
||||
|
||||
func (d *db) append(key string, value interface{}) {
|
||||
kv := d.allocPage()
|
||||
kv.k = key
|
||||
kv.v = value
|
||||
}
|
||||
|
||||
func (d *db) indexOf(key string) int {
|
||||
n := len(d.d)
|
||||
for i := 0; i < n; i++ {
|
||||
if d.d[i].k == key {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
@ -1,365 +0,0 @@
|
||||
package session
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *db) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "d":
|
||||
var zb0002 uint32
|
||||
zb0002, err = dc.ReadArrayHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
if cap(z.d) >= int(zb0002) {
|
||||
z.d = (z.d)[:zb0002]
|
||||
} else {
|
||||
z.d = make([]kv, zb0002)
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
var zb0003 uint32
|
||||
zb0003, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
for zb0003 > 0 {
|
||||
zb0003--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.d[za0001].k, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.d[za0001].v, err = dc.ReadIntf()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *db) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 1
|
||||
// write "d"
|
||||
err = en.Append(0x81, 0xa1, 0x64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteArrayHeader(uint32(len(z.d)))
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
// map header, size 2
|
||||
// write "k"
|
||||
err = en.Append(0x82, 0xa1, 0x6b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteString(z.d[za0001].k)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
// write "v"
|
||||
err = en.Append(0xa1, 0x76)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteIntf(z.d[za0001].v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *db) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 1
|
||||
// string "d"
|
||||
o = append(o, 0x81, 0xa1, 0x64)
|
||||
o = msgp.AppendArrayHeader(o, uint32(len(z.d)))
|
||||
for za0001 := range z.d {
|
||||
// map header, size 2
|
||||
// string "k"
|
||||
o = append(o, 0x82, 0xa1, 0x6b)
|
||||
o = msgp.AppendString(o, z.d[za0001].k)
|
||||
// string "v"
|
||||
o = append(o, 0xa1, 0x76)
|
||||
o, err = msgp.AppendIntf(o, z.d[za0001].v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "d":
|
||||
var zb0002 uint32
|
||||
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d")
|
||||
return
|
||||
}
|
||||
if cap(z.d) >= int(zb0002) {
|
||||
z.d = (z.d)[:zb0002]
|
||||
} else {
|
||||
z.d = make([]kv, zb0002)
|
||||
}
|
||||
for za0001 := range z.d {
|
||||
var zb0003 uint32
|
||||
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
for zb0003 > 0 {
|
||||
zb0003--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "d", za0001)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *db) Msgsize() (s int) {
|
||||
s = 1 + 2 + msgp.ArrayHeaderSize
|
||||
for za0001 := range z.d {
|
||||
s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.k, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.v, err = dc.ReadIntf()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z kv) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 2
|
||||
// write "k"
|
||||
err = en.Append(0x82, 0xa1, 0x6b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteString(z.k)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
// write "v"
|
||||
err = en.Append(0xa1, 0x76)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteIntf(z.v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z kv) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 2
|
||||
// string "k"
|
||||
o = append(o, 0x82, 0xa1, 0x6b)
|
||||
o = msgp.AppendString(o, z.k)
|
||||
// string "v"
|
||||
o = append(o, 0xa1, 0x76)
|
||||
o, err = msgp.AppendIntf(o, z.v)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "k":
|
||||
z.k, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "k")
|
||||
return
|
||||
}
|
||||
case "v":
|
||||
z.v, bts, err = msgp.ReadIntfBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "v")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z kv) Msgsize() (s int) {
|
||||
s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v)
|
||||
return
|
||||
}
|
@ -5,16 +5,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/gotiny"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ctx *fiber.Ctx
|
||||
config *Store
|
||||
db *db
|
||||
id string
|
||||
fresh bool
|
||||
id string // session id
|
||||
fresh bool // if new session
|
||||
ctx *fiber.Ctx // fiber context
|
||||
config *Store // store configuration
|
||||
data *data // key value data
|
||||
}
|
||||
|
||||
var sessionPool = sync.Pool{
|
||||
@ -25,19 +26,20 @@ var sessionPool = sync.Pool{
|
||||
|
||||
func acquireSession() *Session {
|
||||
s := sessionPool.Get().(*Session)
|
||||
s.db = new(db)
|
||||
if s.data == nil {
|
||||
s.data = acquireData()
|
||||
}
|
||||
s.fresh = true
|
||||
return s
|
||||
}
|
||||
|
||||
func releaseSession(s *Session) {
|
||||
s.id = ""
|
||||
s.ctx = nil
|
||||
s.config = nil
|
||||
if s.db != nil {
|
||||
s.db.Reset()
|
||||
if s.data != nil {
|
||||
s.data.Reset()
|
||||
}
|
||||
s.id = ""
|
||||
s.fresh = true
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
@ -53,25 +55,42 @@ func (s *Session) ID() string {
|
||||
|
||||
// Get will return the value
|
||||
func (s *Session) Get(key string) interface{} {
|
||||
return s.db.Get(key)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
return s.data.Get(key)
|
||||
}
|
||||
|
||||
// Set will update or create a new key value
|
||||
func (s *Session) Set(key string, val interface{}) {
|
||||
s.db.Set(key, val)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Set(key, val)
|
||||
}
|
||||
|
||||
// Delete will delete the value
|
||||
func (s *Session) Delete(key string) {
|
||||
s.db.Delete(key)
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Delete(key)
|
||||
}
|
||||
|
||||
// Destroy will delete the session from Storage and expire session cookie
|
||||
func (s *Session) Destroy() error {
|
||||
// Reset local data
|
||||
s.db.Reset()
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete data from storage
|
||||
// Reset local data
|
||||
s.data.Reset()
|
||||
|
||||
// Use external Storage if exist
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -88,6 +107,7 @@ func (s *Session) Regenerate() error {
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create new ID
|
||||
s.id = s.config.KeyGenerator()
|
||||
|
||||
@ -96,26 +116,34 @@ func (s *Session) Regenerate() error {
|
||||
|
||||
// Save will update the storage and client cookie
|
||||
func (s *Session) Save() error {
|
||||
// Don't save to Storage if no data is available
|
||||
if s.db.Len() <= 0 {
|
||||
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert book to bytes
|
||||
data, err := s.db.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
// Create cookie with the session ID if fresh
|
||||
if s.fresh {
|
||||
s.setCookie()
|
||||
}
|
||||
|
||||
// Don't save to Storage if no data is available
|
||||
if s.data.Len() <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert data to bytes
|
||||
mux.Lock()
|
||||
data := gotiny.Marshal(&s.data)
|
||||
mux.Unlock()
|
||||
|
||||
// pass raw bytes with session id to provider
|
||||
if err := s.config.Storage.Set(s.id, data, s.config.Expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create cookie with the session ID
|
||||
s.setCookie()
|
||||
|
||||
// release session to pool to be re-used on next request
|
||||
// Release session
|
||||
// TODO: It's not safe to use the Session after called Save()
|
||||
releaseSession(s)
|
||||
|
||||
return nil
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -66,6 +67,108 @@ func Test_Session(t *testing.T) {
|
||||
utils.AssertEqual(t, 36, len(id))
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Types
|
||||
func Test_Session_Types(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// session store
|
||||
store := New()
|
||||
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.CookieName, "123")
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
}
|
||||
var vuser = User{
|
||||
Name: "John",
|
||||
}
|
||||
// set value
|
||||
var vbool bool = true
|
||||
var vstring string = "str"
|
||||
var vint int = 13
|
||||
var vint8 int8 = 13
|
||||
var vint16 int16 = 13
|
||||
var vint32 int32 = 13
|
||||
var vint64 int64 = 13
|
||||
var vuint uint = 13
|
||||
var vuint8 uint8 = 13
|
||||
var vuint16 uint16 = 13
|
||||
var vuint32 uint32 = 13
|
||||
var vuint64 uint64 = 13
|
||||
var vuintptr uintptr = 13
|
||||
var vbyte byte = 'k'
|
||||
var vrune rune = 'k'
|
||||
var vfloat32 float32 = 13
|
||||
var vfloat64 float64 = 13
|
||||
var vcomplex64 complex64 = 13
|
||||
var vcomplex128 complex128 = 13
|
||||
sess.Set("vuser", vuser)
|
||||
sess.Set("vbool", vbool)
|
||||
sess.Set("vstring", vstring)
|
||||
sess.Set("vint", vint)
|
||||
sess.Set("vint8", vint8)
|
||||
sess.Set("vint16", vint16)
|
||||
sess.Set("vint32", vint32)
|
||||
sess.Set("vint64", vint64)
|
||||
sess.Set("vuint", vuint)
|
||||
sess.Set("vuint8", vuint8)
|
||||
sess.Set("vuint16", vuint16)
|
||||
sess.Set("vuint32", vuint32)
|
||||
sess.Set("vuint32", vuint32)
|
||||
sess.Set("vuint64", vuint64)
|
||||
sess.Set("vuintptr", vuintptr)
|
||||
sess.Set("vbyte", vbyte)
|
||||
sess.Set("vrune", vrune)
|
||||
sess.Set("vfloat32", vfloat32)
|
||||
sess.Set("vfloat64", vfloat64)
|
||||
sess.Set("vcomplex64", vcomplex64)
|
||||
sess.Set("vcomplex128", vcomplex128)
|
||||
|
||||
// save session
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// get session
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
|
||||
// get value
|
||||
utils.AssertEqual(t, vuser, sess.Get("vuser").(User))
|
||||
utils.AssertEqual(t, vbool, sess.Get("vbool").(bool))
|
||||
utils.AssertEqual(t, vstring, sess.Get("vstring").(string))
|
||||
utils.AssertEqual(t, vint, sess.Get("vint").(int))
|
||||
utils.AssertEqual(t, vint8, sess.Get("vint8").(int8))
|
||||
utils.AssertEqual(t, vint16, sess.Get("vint16").(int16))
|
||||
utils.AssertEqual(t, vint32, sess.Get("vint32").(int32))
|
||||
utils.AssertEqual(t, vint64, sess.Get("vint64").(int64))
|
||||
utils.AssertEqual(t, vuint, sess.Get("vuint").(uint))
|
||||
utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8))
|
||||
utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16))
|
||||
utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32))
|
||||
utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64))
|
||||
utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr))
|
||||
utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte))
|
||||
utils.AssertEqual(t, vrune, sess.Get("vrune").(rune))
|
||||
utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32))
|
||||
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
|
||||
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
|
||||
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Store_Reset
|
||||
func Test_Session_Store_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -167,6 +270,37 @@ func Test_Session_Cookie(t *testing.T) {
|
||||
sess, _ := store.Get(ctx)
|
||||
sess.Save()
|
||||
|
||||
// cookie should not be set if empty data
|
||||
utils.AssertEqual(t, 0, len(ctx.Response().Header.PeekCookie(store.CookieName)))
|
||||
// cookie should be set on Save ( even if empty data )
|
||||
utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.CookieName)))
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
|
||||
func Benchmark_Session(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.CookieName, "12356789")
|
||||
|
||||
b.Run("default", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
store = New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c)
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -1,7 +1,10 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/gotiny"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
)
|
||||
|
||||
@ -9,8 +12,7 @@ type Store struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Storage ErrNotExist
|
||||
var errNotExist = "key does not exist"
|
||||
var mux sync.Mutex
|
||||
|
||||
func New(config ...Config) *Store {
|
||||
// Set default config
|
||||
@ -25,6 +27,13 @@ func New(config ...Config) *Store {
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterType will allow you to encode/decode custom types
|
||||
// into any Storage provider
|
||||
func (s *Store) RegisterType(i interface{}) {
|
||||
gotiny.Register(i)
|
||||
}
|
||||
|
||||
// Get will get/create a session
|
||||
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
||||
var fresh bool
|
||||
|
||||
@ -42,19 +51,21 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
||||
sess.ctx = c
|
||||
sess.config = s
|
||||
sess.id = id
|
||||
sess.fresh = fresh
|
||||
|
||||
// Fetch existing data
|
||||
if !fresh {
|
||||
raw, err := s.Storage.Get(id)
|
||||
// Unmashal if we found data
|
||||
if err == nil {
|
||||
if _, err = sess.db.UnmarshalMsg(raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw != nil && err == nil {
|
||||
mux.Lock()
|
||||
gotiny.Unmarshal(raw, &sess.data)
|
||||
mux.Unlock()
|
||||
sess.fresh = false
|
||||
} else if err.Error() != errNotExist {
|
||||
// Only return error if it's not ErrNotExist
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
sess.fresh = true
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -349,9 +349,7 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router {
|
||||
if maxAge > 0 {
|
||||
cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
|
||||
}
|
||||
if config[0].CacheDuration != 0 {
|
||||
fs.CacheDuration = config[0].CacheDuration
|
||||
}
|
||||
fs.CacheDuration = config[0].CacheDuration
|
||||
fs.Compress = config[0].Compress
|
||||
fs.AcceptByteRange = config[0].ByteRange
|
||||
fs.GenerateIndexPages = config[0].Browse
|
||||
|
@ -16,17 +16,19 @@ import (
|
||||
)
|
||||
|
||||
// AssertEqual checks if values are equal
|
||||
func AssertEqual(t testing.TB, expected interface{}, actual interface{}, description ...string) {
|
||||
func AssertEqual(t testing.TB, expected, actual interface{}, description ...string) {
|
||||
if reflect.DeepEqual(expected, actual) {
|
||||
return
|
||||
}
|
||||
|
||||
var aType = "<nil>"
|
||||
var bType = "<nil>"
|
||||
if reflect.ValueOf(expected).IsValid() {
|
||||
aType = reflect.TypeOf(expected).Name()
|
||||
|
||||
if expected != nil {
|
||||
aType = fmt.Sprintf("%s", reflect.TypeOf(expected))
|
||||
}
|
||||
if reflect.ValueOf(actual).IsValid() {
|
||||
bType = reflect.TypeOf(actual).Name()
|
||||
if actual != nil {
|
||||
bType = fmt.Sprintf("%s", reflect.TypeOf(actual))
|
||||
}
|
||||
|
||||
testName := "AssertEqual"
|
||||
@ -40,13 +42,11 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip
|
||||
w := tabwriter.NewWriter(&buf, 0, 0, 5, ' ', 0)
|
||||
fmt.Fprintf(w, "\nTest:\t%s", testName)
|
||||
fmt.Fprintf(w, "\nTrace:\t%s:%d", filepath.Base(file), line)
|
||||
fmt.Fprintf(w, "\nError:\tNot equal")
|
||||
fmt.Fprintf(w, "\nExpect:\t%v\t[%s]", expected, aType)
|
||||
fmt.Fprintf(w, "\nResult:\t%v\t[%s]", actual, bType)
|
||||
|
||||
if len(description) > 0 {
|
||||
fmt.Fprintf(w, "\nDescription:\t%s", description[0])
|
||||
}
|
||||
fmt.Fprintf(w, "\nExpect:\t%v\t(%s)", expected, aType)
|
||||
fmt.Fprintf(w, "\nResult:\t%v\t(%s)", actual, bType)
|
||||
|
||||
result := ""
|
||||
if err := w.Flush(); err != nil {
|
||||
@ -54,6 +54,7 @@ func AssertEqual(t testing.TB, expected interface{}, actual interface{}, descrip
|
||||
} else {
|
||||
result = buf.String()
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
t.Fatal(result)
|
||||
} else {
|
||||
|
@ -6,7 +6,7 @@ package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_Utils_AssertEqual(t *testing.T) {
|
||||
func Test_AssertEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
AssertEqual(nil, []string{}, []string{})
|
||||
AssertEqual(t, []string{}, []string{})
|
||||
|
@ -56,7 +56,7 @@ func TrimBytes(b []byte, cutset byte) []byte {
|
||||
}
|
||||
|
||||
// EqualFold the equivalent of bytes.EqualFold
|
||||
func EqualsFold(b, s []byte) (equals bool) {
|
||||
func EqualFoldBytes(b, s []byte) (equals bool) {
|
||||
n := len(b)
|
||||
equals = n == len(s)
|
||||
if equals {
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_ToLowerBytes(t *testing.T) {
|
||||
func Test_ToLowerBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToLowerBytes([]byte("/MY/NAME/IS/:PARAM/*"))
|
||||
AssertEqual(t, true, bytes.Equal([]byte("/my/name/is/:param/*"), res))
|
||||
@ -41,7 +41,7 @@ func Benchmark_ToLowerBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ToUpperBytes(t *testing.T) {
|
||||
func Test_ToUpperBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToUpperBytes([]byte("/my/name/is/:param/*"))
|
||||
AssertEqual(t, true, bytes.Equal([]byte("/MY/NAME/IS/:PARAM/*"), res))
|
||||
@ -73,7 +73,7 @@ func Benchmark_ToUpperBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimRightBytes(t *testing.T) {
|
||||
func Test_TrimRightBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimRightBytes([]byte("/test//////"), '/')
|
||||
AssertEqual(t, []byte("/test"), res)
|
||||
@ -99,7 +99,7 @@ func Benchmark_TrimRightBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimLeftBytes(t *testing.T) {
|
||||
func Test_TrimLeftBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimLeftBytes([]byte("////test/"), '/')
|
||||
AssertEqual(t, []byte("test/"), res)
|
||||
@ -123,7 +123,7 @@ func Benchmark_TrimLeftBytes(b *testing.B) {
|
||||
AssertEqual(b, []byte("foobar"), res)
|
||||
})
|
||||
}
|
||||
func Test_Utils_TrimBytes(t *testing.T) {
|
||||
func Test_TrimBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimBytes([]byte(" test "), ' ')
|
||||
AssertEqual(t, []byte("test"), res)
|
||||
@ -151,14 +151,14 @@ func Benchmark_TrimBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_EqualFolds(b *testing.B) {
|
||||
func Benchmark_EqualFoldBytes(b *testing.B) {
|
||||
var left = []byte("/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts")
|
||||
var right = []byte("/RePos/goFiber/Fiber/issues/187643/COMMENTS")
|
||||
var res bool
|
||||
|
||||
b.Run("fiber", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = EqualsFold(left, right)
|
||||
res = EqualFoldBytes(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
@ -170,18 +170,18 @@ func Benchmark_EqualFolds(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_EqualsFold(t *testing.T) {
|
||||
func Test_EqualFoldBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := EqualsFold([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
|
||||
res := EqualFoldBytes([]byte("/MY/NAME/IS/:PARAM/*"), []byte("/my/name/is/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
|
||||
res = EqualFoldBytes([]byte("/MY1/NAME/IS/:PARAM/*"), []byte("/MY1/NAME/IS/:PARAM/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
|
||||
res = EqualFoldBytes([]byte("/my2/name/is/:param/*"), []byte("/my2/name"))
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualsFold([]byte("/dddddd"), []byte("eeeeee"))
|
||||
res = EqualFoldBytes([]byte("/dddddd"), []byte("eeeeee"))
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualsFold([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
|
||||
res = EqualFoldBytes([]byte("/MY3/NAME/IS/:PARAM/*"), []byte("/my3/name/is/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualsFold([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
|
||||
res = EqualFoldBytes([]byte("/MY4/NAME/IS/:PARAM/*"), []byte("/my4/nAME/IS/:param/*"))
|
||||
AssertEqual(t, true, res)
|
||||
}
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
googleuuid "github.com/gofiber/fiber/v2/internal/uuid"
|
||||
)
|
||||
|
||||
const toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff"
|
||||
@ -63,6 +65,16 @@ func UUID() string {
|
||||
return GetString(b)
|
||||
}
|
||||
|
||||
// UUIDv4 returns a Random (Version 4) UUID.
|
||||
// The strength of the UUIDs is based on the strength of the crypto/rand package.
|
||||
func UUIDv4() string {
|
||||
token, err := googleuuid.NewRandom()
|
||||
if err != nil {
|
||||
return UUID()
|
||||
}
|
||||
return token.String()
|
||||
}
|
||||
|
||||
// FunctionName returns function name
|
||||
func FunctionName(fn interface{}) string {
|
||||
t := reflect.ValueOf(fn).Type()
|
||||
|
@ -10,26 +10,26 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_FunctionName(t *testing.T) {
|
||||
func Test_FunctionName(t *testing.T) {
|
||||
t.Parallel()
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_UUID", FunctionName(Test_Utils_UUID))
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_UUID", FunctionName(Test_UUID))
|
||||
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_Utils_FunctionName.func1", FunctionName(func() {}))
|
||||
AssertEqual(t, "github.com/gofiber/fiber/v2/utils.Test_FunctionName.func1", FunctionName(func() {}))
|
||||
|
||||
var dummyint = 20
|
||||
AssertEqual(t, "int", FunctionName(dummyint))
|
||||
}
|
||||
|
||||
func Test_Utils_UUID(t *testing.T) {
|
||||
func Test_UUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := UUID()
|
||||
AssertEqual(t, 36, len(res))
|
||||
AssertEqual(t, true, res != "00000000-0000-0000-0000-000000000000")
|
||||
}
|
||||
|
||||
func Test_Utils_UUID_Concurrency(t *testing.T) {
|
||||
func Test_UUID_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
iterations := 10000
|
||||
iterations := 1000
|
||||
var res string
|
||||
ch := make(chan string, iterations)
|
||||
results := make(map[string]string)
|
||||
@ -45,6 +45,30 @@ func Test_Utils_UUID_Concurrency(t *testing.T) {
|
||||
AssertEqual(t, iterations, len(results))
|
||||
}
|
||||
|
||||
func Test_UUIDv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := UUIDv4()
|
||||
AssertEqual(t, 36, len(res))
|
||||
AssertEqual(t, true, res != "00000000-0000-0000-0000-000000000000")
|
||||
}
|
||||
func Test_UUIDv4_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
iterations := 1000
|
||||
var res string
|
||||
ch := make(chan string, iterations)
|
||||
results := make(map[string]string)
|
||||
for i := 0; i < iterations; i++ {
|
||||
go func() {
|
||||
ch <- UUIDv4()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < iterations; i++ {
|
||||
res = <-ch
|
||||
results[res] = res
|
||||
}
|
||||
AssertEqual(t, iterations, len(results))
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_UUID -benchmem -count=2
|
||||
|
||||
func Benchmark_UUID(b *testing.B) {
|
||||
|
@ -28,13 +28,13 @@ func UnsafeBytes(s string) (bs []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// SafeString copies a string to make it immutable
|
||||
func SafeString(s string) string {
|
||||
// CopyString copies a string to make it immutable
|
||||
func CopyString(s string) string {
|
||||
return string(UnsafeBytes(s))
|
||||
}
|
||||
|
||||
// SafeBytes copies a slice to make it immutable
|
||||
func SafeBytes(b []byte) []byte {
|
||||
// CopyBytes copies a slice to make it immutable
|
||||
func CopyBytes(b []byte) []byte {
|
||||
tmp := make([]byte, len(b))
|
||||
copy(tmp, b)
|
||||
return tmp
|
||||
@ -83,22 +83,3 @@ func ByteSize(bytes uint64) string {
|
||||
result = strings.TrimSuffix(result, ".0")
|
||||
return result + unit
|
||||
}
|
||||
|
||||
// Deprecated fn's
|
||||
|
||||
// #nosec G103
|
||||
// GetString returns a string pointer without allocation
|
||||
func GetString(b []byte) string {
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// #nosec G103
|
||||
// GetBytes returns a byte pointer without allocation
|
||||
func GetBytes(s string) []byte {
|
||||
return UnsafeBytes(s)
|
||||
}
|
||||
|
||||
// ImmutableString copies a string to make it immutable
|
||||
func ImmutableString(s string) string {
|
||||
return SafeString(s)
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_Utils_GetString(t *testing.T) {
|
||||
func Test_GetString(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetString([]byte("Hello, World!"))
|
||||
AssertEqual(t, "Hello, World!", res)
|
||||
@ -31,7 +31,7 @@ func Benchmark_GetString(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_GetBytes(t *testing.T) {
|
||||
func Test_GetBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetBytes("Hello, World!")
|
||||
AssertEqual(t, []byte("Hello, World!"), res)
|
||||
@ -56,7 +56,7 @@ func Benchmark_GetBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ImmutableString(t *testing.T) {
|
||||
func Test_ImmutableString(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ImmutableString("Hello, World!")
|
||||
AssertEqual(t, "Hello, World!", res)
|
||||
|
33
utils/deprecated.go
Normal file
33
utils/deprecated.go
Normal file
@ -0,0 +1,33 @@
|
||||
package utils
|
||||
|
||||
// #nosec G103
|
||||
// DEPRECATED, Please use UnsafeString instead
|
||||
func GetString(b []byte) string {
|
||||
return UnsafeString(b)
|
||||
}
|
||||
|
||||
// #nosec G103
|
||||
// DEPRECATED, Please use UnsafeBytes instead
|
||||
func GetBytes(s string) []byte {
|
||||
return UnsafeBytes(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyString instead
|
||||
func ImmutableString(s string) string {
|
||||
return CopyString(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, please use EqualFoldBytes
|
||||
func EqualsFold(b, s []byte) (equals bool) {
|
||||
return EqualFoldBytes(b, s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyString instead
|
||||
func SafeString(s string) string {
|
||||
return CopyString(s)
|
||||
}
|
||||
|
||||
// DEPRECATED, Please use CopyBytes instead
|
||||
func SafeBytes(b []byte) []byte {
|
||||
return CopyBytes(b)
|
||||
}
|
@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_GetMIME(t *testing.T) {
|
||||
func Test_GetMIME(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := GetMIME(".json")
|
||||
AssertEqual(t, "application/json", res)
|
||||
@ -53,7 +53,7 @@ func Benchmark_GetMIME(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_StatusMessage(t *testing.T) {
|
||||
func Test_StatusMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := StatusMessage(204)
|
||||
AssertEqual(t, "No Content", res)
|
||||
|
@ -60,3 +60,17 @@ func TrimRight(s string, cutset byte) string {
|
||||
}
|
||||
return s[:lenStr]
|
||||
}
|
||||
|
||||
// EqualFold the equivalent of strings.EqualFold
|
||||
func EqualFold(b, s string) (equals bool) {
|
||||
n := len(b)
|
||||
equals = n == len(s)
|
||||
if equals {
|
||||
for i := 0; i < n; i++ {
|
||||
if equals = b[i]|0x20 == s[i]|0x20; !equals {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Utils_ToUpper(t *testing.T) {
|
||||
func Test_ToUpper(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToUpper("/my/name/is/:param/*")
|
||||
AssertEqual(t, "/MY/NAME/IS/:PARAM/*", res)
|
||||
@ -33,7 +33,7 @@ func Benchmark_ToUpper(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_ToLower(t *testing.T) {
|
||||
func Test_ToLower(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := ToLower("/MY/NAME/IS/:PARAM/*")
|
||||
AssertEqual(t, "/my/name/is/:param/*", res)
|
||||
@ -64,7 +64,7 @@ func Benchmark_ToLower(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimRight(t *testing.T) {
|
||||
func Test_TrimRight(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimRight("/test//////", '/')
|
||||
AssertEqual(t, "/test", res)
|
||||
@ -89,7 +89,7 @@ func Benchmark_TrimRight(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Utils_TrimLeft(t *testing.T) {
|
||||
func Test_TrimLeft(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := TrimLeft("////test/", '/')
|
||||
AssertEqual(t, "test/", res)
|
||||
@ -113,7 +113,7 @@ func Benchmark_TrimLeft(b *testing.B) {
|
||||
AssertEqual(b, "foobar", res)
|
||||
})
|
||||
}
|
||||
func Test_Utils_Trim(t *testing.T) {
|
||||
func Test_Trim(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := Trim(" test ", ' ')
|
||||
AssertEqual(t, "test", res)
|
||||
@ -147,3 +147,39 @@ func Benchmark_Trim(b *testing.B) {
|
||||
AssertEqual(b, "foobar", res)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_EqualFold -benchmem -count=4
|
||||
func Benchmark_EqualFold(b *testing.B) {
|
||||
var left = "/RePos/GoFiBer/FibEr/iSsues/187643/CoMmEnts"
|
||||
var right = "/RePos/goFiber/Fiber/issues/187643/COMMENTS"
|
||||
var res bool
|
||||
|
||||
b.Run("fiber", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = EqualFold(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
b.Run("default", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
res = strings.EqualFold(left, right)
|
||||
}
|
||||
AssertEqual(b, true, res)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_EqualFold(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := EqualFold("/MY/NAME/IS/:PARAM/*", "/my/name/is/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/MY1/NAME/IS/:PARAM/*", "/MY1/NAME/IS/:PARAM/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/my2/name/is/:param/*", "/my2/name")
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualFold("/dddddd", "eeeeee")
|
||||
AssertEqual(t, false, res)
|
||||
res = EqualFold("/MY3/NAME/IS/:PARAM/*", "/my3/name/is/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
res = EqualFold("/MY4/NAME/IS/:PARAM/*", "/my4/nAME/IS/:param/*")
|
||||
AssertEqual(t, true, res)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user