mirror of https://github.com/gofiber/fiber.git
160 lines
4.0 KiB
Go
160 lines
4.0 KiB
Go
package idempotency
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/fiber/v3/log"
|
|
"github.com/gofiber/utils/v2"
|
|
)
|
|
|
|
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
|
|
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
|
|
|
|
// contextKey is the type of the keys this middleware stores in a request's
// Locals. Because the type is unexported, no other package can construct an
// equal key, so these keys cannot collide with context keys defined elsewhere.
type contextKey int
|
|
|
|
const (
|
|
localsKeyIsFromCache contextKey = iota //
|
|
localsKeyWasPutToCache
|
|
)
|
|
|
|
func IsFromCache(c fiber.Ctx) bool {
|
|
return c.Locals(localsKeyIsFromCache) != nil
|
|
}
|
|
|
|
func WasPutToCache(c fiber.Ctx) bool {
|
|
return c.Locals(localsKeyWasPutToCache) != nil
|
|
}
|
|
|
|
func New(config ...Config) fiber.Handler {
|
|
// Set default config
|
|
cfg := configDefault(config...)
|
|
|
|
keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
|
|
for _, h := range cfg.KeepResponseHeaders {
|
|
keepResponseHeadersMap[strings.ToLower(h)] = struct{}{}
|
|
}
|
|
|
|
maybeWriteCachedResponse := func(c fiber.Ctx, key string) (bool, error) {
|
|
if val, err := cfg.Storage.Get(key); err != nil {
|
|
return false, fmt.Errorf("failed to read response: %w", err)
|
|
} else if val != nil {
|
|
var res response
|
|
if _, err := res.UnmarshalMsg(val); err != nil {
|
|
return false, fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
_ = c.Status(res.StatusCode)
|
|
|
|
for header, vals := range res.Headers {
|
|
for _, val := range vals {
|
|
c.RequestCtx().Response.Header.Add(header, val)
|
|
}
|
|
}
|
|
|
|
if len(res.Body) != 0 {
|
|
if err := c.Send(res.Body); err != nil {
|
|
return true, err
|
|
}
|
|
}
|
|
|
|
_ = c.Locals(localsKeyIsFromCache, true)
|
|
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
return func(c fiber.Ctx) error {
|
|
// Don't execute middleware if Next returns true
|
|
if cfg.Next != nil && cfg.Next(c) {
|
|
return c.Next()
|
|
}
|
|
|
|
// Don't execute middleware if the idempotency key is empty
|
|
key := utils.CopyString(c.Get(cfg.KeyHeader))
|
|
if key == "" {
|
|
return c.Next()
|
|
}
|
|
|
|
// Validate key
|
|
if err := cfg.KeyHeaderValidate(key); err != nil {
|
|
return err
|
|
}
|
|
|
|
// First-pass: if the idempotency key is in the storage, get and return the response
|
|
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
|
return fmt.Errorf("failed to write cached response at fastpath: %w", err)
|
|
} else if ok {
|
|
return nil
|
|
}
|
|
|
|
if err := cfg.Lock.Lock(key); err != nil {
|
|
return fmt.Errorf("failed to lock: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := cfg.Lock.Unlock(key); err != nil {
|
|
log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", key, err)
|
|
}
|
|
}()
|
|
|
|
// Lock acquired. If the idempotency key now is in the storage, get and return the response
|
|
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
|
return fmt.Errorf("failed to write cached response while locked: %w", err)
|
|
} else if ok {
|
|
return nil
|
|
}
|
|
|
|
// Execute the request handler
|
|
if err := c.Next(); err != nil {
|
|
// If the request handler returned an error, return it and skip idempotency
|
|
return err
|
|
}
|
|
|
|
// Construct response
|
|
res := &response{
|
|
StatusCode: c.Response().StatusCode(),
|
|
|
|
Body: utils.CopyBytes(c.Response().Body()),
|
|
}
|
|
{
|
|
headers := make(map[string][]string)
|
|
if err := c.Bind().RespHeader(headers); err != nil {
|
|
return fmt.Errorf("failed to bind to response headers: %w", err)
|
|
}
|
|
|
|
if cfg.KeepResponseHeaders == nil {
|
|
// Keep all
|
|
res.Headers = headers
|
|
} else {
|
|
// Filter
|
|
res.Headers = make(map[string][]string)
|
|
for h := range headers {
|
|
if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok {
|
|
res.Headers[h] = headers[h]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Marshal response
|
|
bs, err := res.MarshalMsg(nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal response: %w", err)
|
|
}
|
|
|
|
// Store response
|
|
if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil {
|
|
return fmt.Errorf("failed to save response: %w", err)
|
|
}
|
|
|
|
_ = c.Locals(localsKeyWasPutToCache, true)
|
|
|
|
return nil
|
|
}
|
|
}
|