mirror of https://github.com/gofiber/fiber.git
🚀 Feature: Add idempotency middleware (v2 backport) (#2288)
* 🚀 Feature: Add idempotency middleware (#2253) * middleware: add idempotency middleware * middleware/idempotency: use fiber.Storage instead of custom storage * middleware/idempotency: only allocate data if really required * middleware/idempotency: marshal response using msgp * middleware/idempotency: add msgp tests * middleware/idempotency: do not export response * middleware/idempotency: disable msgp's -io option to disable generating unused methods * middleware/idempotency: switch to time.Duration based app.Test * middleware/idempotency: only create closure once * middleware/idempotency: add benchmarks * middleware/idempotency: optimize strings.ToLower when making comparison The real "strings.ToLower" still needs to be used when storing the data. * middleware/idempotency: safe-copy body * middleware/idempotency: backport to v2pull/2296/head
parent
9c5dfdbe5d
commit
adcf92dec1
30
helpers.go
30
helpers.go
|
@ -369,6 +369,36 @@ func (app *App) methodInt(s string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsMethodSafe reports whether the HTTP method is considered safe.
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.1
|
||||||
|
func IsMethodSafe(m string) bool {
|
||||||
|
switch m {
|
||||||
|
case MethodGet,
|
||||||
|
MethodHead,
|
||||||
|
MethodOptions,
|
||||||
|
MethodTrace:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMethodIdempotent reports whether the HTTP method is considered idempotent.
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.2
|
||||||
|
func IsMethodIdempotent(m string) bool {
|
||||||
|
if IsMethodSafe(m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m {
|
||||||
|
case MethodPut,
|
||||||
|
MethodDelete:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// HTTP methods were copied from net/http.
|
// HTTP methods were copied from net/http.
|
||||||
const (
|
const (
|
||||||
MethodGet = "GET" // RFC 7231, 4.3.1
|
MethodGet = "GET" // RFC 7231, 4.3.1
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Idempotency Middleware
|
||||||
|
|
||||||
|
Idempotency middleware for [Fiber](https://github.com/gofiber/fiber) allows for fault-tolerant APIs where duplicate requests — for example due to networking issues on the client-side — do not erroneously cause the same action performed multiple times on the server-side.
|
||||||
|
|
||||||
|
Refer to https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02 for a better understanding.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Idempotency Middleware](#idempotency-middleware)
|
||||||
|
- [Table of Contents](#table-of-contents)
|
||||||
|
- [Signatures](#signatures)
|
||||||
|
- [Examples](#examples)
|
||||||
|
- [Default Config](#default-config)
|
||||||
|
- [Custom Config](#custom-config)
|
||||||
|
- [Config](#config)
|
||||||
|
- [Default Config](#default-config-1)
|
||||||
|
|
||||||
|
## Signatures
|
||||||
|
|
||||||
|
```go
|
||||||
|
func New(config ...Config) fiber.Handler
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
First import the middleware from Fiber,
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then create a Fiber app with `app := fiber.New()`.
|
||||||
|
|
||||||
|
### Default Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
app.Use(idempotency.New())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
app.Use(idempotency.New(idempotency.Config{
|
||||||
|
Lifetime: 42 * time.Minute,
|
||||||
|
// ...
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
// Next defines a function to skip this middleware when returned true.
|
||||||
|
//
|
||||||
|
// Optional. Default: a function which skips the middleware on safe HTTP request method.
|
||||||
|
Next func(c *fiber.Ctx) bool
|
||||||
|
|
||||||
|
// Lifetime is the maximum lifetime of an idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: 30 * time.Minute
|
||||||
|
Lifetime time.Duration
|
||||||
|
|
||||||
|
// KeyHeader is the name of the header that contains the idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: X-Idempotency-Key
|
||||||
|
KeyHeader string
|
||||||
|
// KeyHeaderValidate defines a function to validate the syntax of the idempotency header.
|
||||||
|
//
|
||||||
|
// Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID).
|
||||||
|
KeyHeaderValidate func(string) error
|
||||||
|
|
||||||
|
// KeepResponseHeaders is a list of headers that should be kept from the original response.
|
||||||
|
//
|
||||||
|
// Optional. Default: nil (to keep all headers)
|
||||||
|
KeepResponseHeaders []string
|
||||||
|
|
||||||
|
// Lock locks an idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: an in-memory locker for this process only.
|
||||||
|
Lock Locker
|
||||||
|
|
||||||
|
// Storage stores response data by idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: an in-memory storage for this process only.
|
||||||
|
Storage fiber.Storage
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
var ConfigDefault = Config{
|
||||||
|
Next: func(c *fiber.Ctx) bool {
|
||||||
|
// Skip middleware if the request was done using a safe HTTP method
|
||||||
|
return fiber.IsMethodSafe(c.Method())
|
||||||
|
},
|
||||||
|
|
||||||
|
Lifetime: 30 * time.Minute,
|
||||||
|
|
||||||
|
KeyHeader: "X-Idempotency-Key",
|
||||||
|
KeyHeaderValidate: func(k string) error {
|
||||||
|
if l, wl := len(k), 36; l != wl { // UUID length is 36 chars
|
||||||
|
return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
|
||||||
|
KeepResponseHeaders: nil,
|
||||||
|
|
||||||
|
Lock: nil, // Set in configDefault so we don't allocate data here.
|
||||||
|
|
||||||
|
Storage: nil, // Set in configDefault so we don't allocate data here.
|
||||||
|
}
|
||||||
|
```
|
|
@ -0,0 +1,120 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config defines the config for middleware.
|
||||||
|
type Config struct {
|
||||||
|
// Next defines a function to skip this middleware when returned true.
|
||||||
|
//
|
||||||
|
// Optional. Default: a function which skips the middleware on safe HTTP request method.
|
||||||
|
Next func(c *fiber.Ctx) bool
|
||||||
|
|
||||||
|
// Lifetime is the maximum lifetime of an idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: 30 * time.Minute
|
||||||
|
Lifetime time.Duration
|
||||||
|
|
||||||
|
// KeyHeader is the name of the header that contains the idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: X-Idempotency-Key
|
||||||
|
KeyHeader string
|
||||||
|
// KeyHeaderValidate defines a function to validate the syntax of the idempotency header.
|
||||||
|
//
|
||||||
|
// Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID).
|
||||||
|
KeyHeaderValidate func(string) error
|
||||||
|
|
||||||
|
// KeepResponseHeaders is a list of headers that should be kept from the original response.
|
||||||
|
//
|
||||||
|
// Optional. Default: nil (to keep all headers)
|
||||||
|
KeepResponseHeaders []string
|
||||||
|
|
||||||
|
// Lock locks an idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: an in-memory locker for this process only.
|
||||||
|
Lock Locker
|
||||||
|
|
||||||
|
// Storage stores response data by idempotency key.
|
||||||
|
//
|
||||||
|
// Optional. Default: an in-memory storage for this process only.
|
||||||
|
Storage fiber.Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigDefault is the default config
|
||||||
|
var ConfigDefault = Config{
|
||||||
|
Next: func(c *fiber.Ctx) bool {
|
||||||
|
// Skip middleware if the request was done using a safe HTTP method
|
||||||
|
return fiber.IsMethodSafe(c.Method())
|
||||||
|
},
|
||||||
|
|
||||||
|
Lifetime: 30 * time.Minute,
|
||||||
|
|
||||||
|
KeyHeader: "X-Idempotency-Key",
|
||||||
|
KeyHeaderValidate: func(k string) error {
|
||||||
|
if l, wl := len(k), 36; l != wl { // UUID length is 36 chars
|
||||||
|
return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
|
||||||
|
KeepResponseHeaders: nil,
|
||||||
|
|
||||||
|
Lock: nil, // Set in configDefault so we don't allocate data here.
|
||||||
|
|
||||||
|
Storage: nil, // Set in configDefault so we don't allocate data here.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to set default values
|
||||||
|
func configDefault(config ...Config) Config {
|
||||||
|
// Return default config if nothing provided
|
||||||
|
if len(config) < 1 {
|
||||||
|
return ConfigDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override default config
|
||||||
|
cfg := config[0]
|
||||||
|
|
||||||
|
// Set default values
|
||||||
|
|
||||||
|
if cfg.Next == nil {
|
||||||
|
cfg.Next = ConfigDefault.Next
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Lifetime.Nanoseconds() == 0 {
|
||||||
|
cfg.Lifetime = ConfigDefault.Lifetime
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.KeyHeader == "" {
|
||||||
|
cfg.KeyHeader = ConfigDefault.KeyHeader
|
||||||
|
}
|
||||||
|
if cfg.KeyHeaderValidate == nil {
|
||||||
|
cfg.KeyHeaderValidate = ConfigDefault.KeyHeaderValidate
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.KeepResponseHeaders != nil && len(cfg.KeepResponseHeaders) == 0 {
|
||||||
|
cfg.KeepResponseHeaders = ConfigDefault.KeepResponseHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Lock == nil {
|
||||||
|
cfg.Lock = NewMemoryLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Storage == nil {
|
||||||
|
cfg.Storage = memory.New(memory.Config{
|
||||||
|
GCInterval: cfg.Lifetime / 2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
|
||||||
|
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
|
||||||
|
|
||||||
|
const (
|
||||||
|
localsKeyIsFromCache = "idempotency_isfromcache"
|
||||||
|
localsKeyWasPutToCache = "idempotency_wasputtocache"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsFromCache(c *fiber.Ctx) bool {
|
||||||
|
return c.Locals(localsKeyIsFromCache) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WasPutToCache(c *fiber.Ctx) bool {
|
||||||
|
return c.Locals(localsKeyWasPutToCache) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(config ...Config) fiber.Handler {
|
||||||
|
// Set default config
|
||||||
|
cfg := configDefault(config...)
|
||||||
|
|
||||||
|
keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
|
||||||
|
for _, h := range cfg.KeepResponseHeaders {
|
||||||
|
keepResponseHeadersMap[strings.ToLower(h)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
maybeWriteCachedResponse := func(c *fiber.Ctx, key string) (bool, error) {
|
||||||
|
if val, err := cfg.Storage.Get(key); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
} else if val != nil {
|
||||||
|
var res response
|
||||||
|
if _, err := res.UnmarshalMsg(val); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = c.Status(res.StatusCode)
|
||||||
|
|
||||||
|
for header, val := range res.Headers {
|
||||||
|
c.Set(header, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Body) != 0 {
|
||||||
|
if err := c.Send(res.Body); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = c.Locals(localsKeyIsFromCache, true)
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
// Don't execute middleware if Next returns true
|
||||||
|
if cfg.Next != nil && cfg.Next(c) {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't execute middleware if the idempotency key is empty
|
||||||
|
key := utils.CopyString(c.Get(cfg.KeyHeader))
|
||||||
|
if key == "" {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key
|
||||||
|
if err := cfg.KeyHeaderValidate(key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// First-pass: if the idempotency key is in the storage, get and return the response
|
||||||
|
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
||||||
|
return fmt.Errorf("failed to write cached response at fastpath: %w", err)
|
||||||
|
} else if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Lock.Lock(key); err != nil {
|
||||||
|
return fmt.Errorf("failed to lock: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := cfg.Lock.Unlock(key); err != nil {
|
||||||
|
log.Printf("middleware/idempotency: failed to unlock key %q: %v", key, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Lock acquired. If the idempotency key now is in the storage, get and return the response
|
||||||
|
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
||||||
|
return fmt.Errorf("failed to write cached response while locked: %w", err)
|
||||||
|
} else if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the request handler
|
||||||
|
if err := c.Next(); err != nil {
|
||||||
|
// If the request handler returned an error, return it and skip idempotency
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct response
|
||||||
|
res := &response{
|
||||||
|
StatusCode: c.Response().StatusCode(),
|
||||||
|
|
||||||
|
Body: utils.CopyBytes(c.Response().Body()),
|
||||||
|
}
|
||||||
|
{
|
||||||
|
headers := c.GetRespHeaders()
|
||||||
|
if cfg.KeepResponseHeaders == nil {
|
||||||
|
// Keep all
|
||||||
|
res.Headers = headers
|
||||||
|
} else {
|
||||||
|
// Filter
|
||||||
|
res.Headers = make(map[string]string)
|
||||||
|
for h := range headers {
|
||||||
|
if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok {
|
||||||
|
res.Headers[h] = headers[h]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal response
|
||||||
|
bs, err := res.MarshalMsg(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store response
|
||||||
|
if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil {
|
||||||
|
return fmt.Errorf("failed to save response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = c.Locals(localsKeyWasPutToCache, true)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,176 @@
|
||||||
|
package idempotency_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||||
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// go test -run Test_Idempotency
|
||||||
|
func Test_Idempotency(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(func(c *fiber.Ctx) error {
|
||||||
|
if err := c.Next(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isMethodSafe := fiber.IsMethodSafe(c.Method())
|
||||||
|
isIdempotent := idempotency.IsFromCache(c) || idempotency.WasPutToCache(c)
|
||||||
|
hasReqHeader := c.Get("X-Idempotency-Key") != ""
|
||||||
|
|
||||||
|
if isMethodSafe {
|
||||||
|
if isIdempotent {
|
||||||
|
return errors.New("request with safe HTTP method should not be idempotent")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unsafe
|
||||||
|
if hasReqHeader {
|
||||||
|
if !isIdempotent {
|
||||||
|
return errors.New("request with unsafe HTTP method should be idempotent if X-Idempotency-Key request header is set")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No request header
|
||||||
|
if isIdempotent {
|
||||||
|
return errors.New("request with unsafe HTTP method should not be idempotent if X-Idempotency-Key request header is not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Needs to be at least a second as the memory storage doesn't support shorter durations.
|
||||||
|
const lifetime = 1 * time.Second
|
||||||
|
|
||||||
|
app.Use(idempotency.New(idempotency.Config{
|
||||||
|
Lifetime: lifetime,
|
||||||
|
}))
|
||||||
|
|
||||||
|
nextCount := func() func() int {
|
||||||
|
var count int32
|
||||||
|
return func() int {
|
||||||
|
return int(atomic.AddInt32(&count, 1))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
{
|
||||||
|
handler := func(c *fiber.Ctx) error {
|
||||||
|
return c.SendString(strconv.Itoa(nextCount()))
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Get("/", handler)
|
||||||
|
app.Post("/", handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Post("/slow", func(c *fiber.Ctx) error {
|
||||||
|
time.Sleep(2 * lifetime)
|
||||||
|
|
||||||
|
return c.SendString(strconv.Itoa(nextCount()))
|
||||||
|
})
|
||||||
|
|
||||||
|
doReq := func(method, route, idempotencyKey string) string {
|
||||||
|
req := httptest.NewRequest(method, route, http.NoBody)
|
||||||
|
if idempotencyKey != "" {
|
||||||
|
req.Header.Set("X-Idempotency-Key", idempotencyKey)
|
||||||
|
}
|
||||||
|
resp, err := app.Test(req, 3*int(lifetime.Milliseconds()))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, string(body))
|
||||||
|
return string(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "1", doReq(fiber.MethodGet, "/", ""))
|
||||||
|
utils.AssertEqual(t, "2", doReq(fiber.MethodGet, "/", ""))
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "3", doReq(fiber.MethodPost, "/", ""))
|
||||||
|
utils.AssertEqual(t, "4", doReq(fiber.MethodPost, "/", ""))
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "5", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
utils.AssertEqual(t, "6", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
utils.AssertEqual(t, "8", doReq(fiber.MethodPost, "/", ""))
|
||||||
|
utils.AssertEqual(t, "9", doReq(fiber.MethodPost, "/", "11111111-1111-1111-1111-111111111111"))
|
||||||
|
|
||||||
|
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
time.Sleep(2 * lifetime)
|
||||||
|
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||||
|
|
||||||
|
// Test raciness
|
||||||
|
{
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||||
|
}
|
||||||
|
time.Sleep(2 * lifetime)
|
||||||
|
utils.AssertEqual(t, "12", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// go test -v -run=^$ -bench=Benchmark_Idempotency -benchmem -count=4
|
||||||
|
func Benchmark_Idempotency(b *testing.B) {
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
// Needs to be at least a second as the memory storage doesn't support shorter durations.
|
||||||
|
const lifetime = 1 * time.Second
|
||||||
|
|
||||||
|
app.Use(idempotency.New(idempotency.Config{
|
||||||
|
Lifetime: lifetime,
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Post("/", func(c *fiber.Ctx) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
h := app.Handler()
|
||||||
|
|
||||||
|
b.Run("hit", func(b *testing.B) {
|
||||||
|
c := &fasthttp.RequestCtx{}
|
||||||
|
c.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
|
c.Request.SetRequestURI("/")
|
||||||
|
c.Request.Header.Set("X-Idempotency-Key", "00000000-0000-0000-0000-000000000000")
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
h(c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("skip", func(b *testing.B) {
|
||||||
|
c := &fasthttp.RequestCtx{}
|
||||||
|
c.Request.Header.SetMethod(fiber.MethodPost)
|
||||||
|
c.Request.SetRequestURI("/")
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
h(c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Locker implements a spinlock for a string key.
|
||||||
|
type Locker interface {
|
||||||
|
Lock(key string) error
|
||||||
|
Unlock(key string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type MemoryLock struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
keys map[string]*sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *MemoryLock) Lock(key string) error {
|
||||||
|
l.mu.Lock()
|
||||||
|
mu, ok := l.keys[key]
|
||||||
|
if !ok {
|
||||||
|
mu = new(sync.Mutex)
|
||||||
|
l.keys[key] = mu
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *MemoryLock) Unlock(key string) error {
|
||||||
|
l.mu.Lock()
|
||||||
|
mu, ok := l.keys[key]
|
||||||
|
l.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
// This happens if we try to unlock an unknown key
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemoryLock() *MemoryLock {
|
||||||
|
return &MemoryLock{
|
||||||
|
keys: make(map[string]*sync.Mutex),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Locker = (*MemoryLock)(nil)
|
|
@ -0,0 +1,59 @@
|
||||||
|
package idempotency_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||||
|
"github.com/gofiber/fiber/v2/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// go test -run Test_MemoryLock
|
||||||
|
func Test_MemoryLock(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
l := idempotency.NewMemoryLock()
|
||||||
|
|
||||||
|
{
|
||||||
|
err := l.Lock("a")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
err := l.Lock("a")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Fatal("lock acquired again")
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
err := l.Lock("b")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
err := l.Unlock("b")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
err := l.Lock("b")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
err := l.Unlock("c")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
err := l.Lock("d")
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
//go:generate msgp -o=response_msgp.go -io=false -unexported
|
||||||
|
type response struct {
|
||||||
|
StatusCode int `msg:"sc"`
|
||||||
|
|
||||||
|
Headers map[string]string `msg:"hs"`
|
||||||
|
|
||||||
|
Body []byte `msg:"b"`
|
||||||
|
}
|
|
@ -0,0 +1,112 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MarshalMsg implements msgp.Marshaler
|
||||||
|
func (z *response) MarshalMsg(b []byte) (o []byte, err error) {
|
||||||
|
o = msgp.Require(b, z.Msgsize())
|
||||||
|
// map header, size 3
|
||||||
|
// string "sc"
|
||||||
|
o = append(o, 0x83, 0xa2, 0x73, 0x63)
|
||||||
|
o = msgp.AppendInt(o, z.StatusCode)
|
||||||
|
// string "hs"
|
||||||
|
o = append(o, 0xa2, 0x68, 0x73)
|
||||||
|
o = msgp.AppendMapHeader(o, uint32(len(z.Headers)))
|
||||||
|
for za0001, za0002 := range z.Headers {
|
||||||
|
o = msgp.AppendString(o, za0001)
|
||||||
|
o = msgp.AppendString(o, za0002)
|
||||||
|
}
|
||||||
|
// string "b"
|
||||||
|
o = append(o, 0xa1, 0x62)
|
||||||
|
o = msgp.AppendBytes(o, z.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalMsg implements msgp.Unmarshaler
|
||||||
|
func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
|
var field []byte
|
||||||
|
_ = field
|
||||||
|
var zb0001 uint32
|
||||||
|
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for zb0001 > 0 {
|
||||||
|
zb0001--
|
||||||
|
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgp.UnsafeString(field) {
|
||||||
|
case "sc":
|
||||||
|
z.StatusCode, bts, err = msgp.ReadIntBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "StatusCode")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "hs":
|
||||||
|
var zb0002 uint32
|
||||||
|
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Headers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if z.Headers == nil {
|
||||||
|
z.Headers = make(map[string]string, zb0002)
|
||||||
|
} else if len(z.Headers) > 0 {
|
||||||
|
for key := range z.Headers {
|
||||||
|
delete(z.Headers, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for zb0002 > 0 {
|
||||||
|
var za0001 string
|
||||||
|
var za0002 string
|
||||||
|
zb0002--
|
||||||
|
za0001, bts, err = msgp.ReadStringBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Headers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
za0002, bts, err = msgp.ReadStringBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Headers", za0001)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
z.Headers[za0001] = za0002
|
||||||
|
}
|
||||||
|
case "b":
|
||||||
|
z.Body, bts, err = msgp.ReadBytesBytes(bts, z.Body)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err, "Body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
bts, err = msgp.Skip(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o = bts
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||||
|
func (z *response) Msgsize() (s int) {
|
||||||
|
s = 1 + 3 + msgp.IntSize + 3 + msgp.MapHeaderSize
|
||||||
|
if z.Headers != nil {
|
||||||
|
for za0001, za0002 := range z.Headers {
|
||||||
|
_ = za0002
|
||||||
|
s += msgp.StringPrefixSize + len(za0001) + msgp.StringPrefixSize + len(za0002)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s += 2 + msgp.BytesPrefixSize + len(z.Body)
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
package idempotency
|
||||||
|
|
||||||
|
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2/internal/msgp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalresponse(t *testing.T) {
|
||||||
|
v := response{}
|
||||||
|
bts, err := v.MarshalMsg(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
left, err := v.UnmarshalMsg(bts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(left) > 0 {
|
||||||
|
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, err = msgp.Skip(bts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(left) > 0 {
|
||||||
|
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMarshalMsgresponse(b *testing.B) {
|
||||||
|
v := response{}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
v.MarshalMsg(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAppendMsgresponse(b *testing.B) {
|
||||||
|
v := response{}
|
||||||
|
bts := make([]byte, 0, v.Msgsize())
|
||||||
|
bts, _ = v.MarshalMsg(bts[0:0])
|
||||||
|
b.SetBytes(int64(len(bts)))
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bts, _ = v.MarshalMsg(bts[0:0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnmarshalresponse(b *testing.B) {
|
||||||
|
v := response{}
|
||||||
|
bts, _ := v.MarshalMsg(nil)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(bts)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := v.UnmarshalMsg(bts)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue