fiber/middleware/idempotency/idempotency.go

160 lines
4.0 KiB
Go

package idempotency
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
// contextKey is the type of the keys this package stores in the request
// locals. It is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
const (
// localsKeyIsFromCache is set by the middleware returned from New after it
// has written a stored response to the current request.
localsKeyIsFromCache contextKey = iota
// localsKeyWasPutToCache is set by the middleware returned from New after it
// has saved the current response to the storage.
localsKeyWasPutToCache
)
// IsFromCache reports whether the localsKeyIsFromCache local is set on c,
// which the idempotency middleware does after it has answered the request
// from a stored response.
func IsFromCache(c fiber.Ctx) bool {
	marker := c.Locals(localsKeyIsFromCache)
	return marker != nil
}
// WasPutToCache reports whether the localsKeyWasPutToCache local is set on c,
// which the idempotency middleware does after it has saved the response of
// the request to the storage.
func WasPutToCache(c fiber.Ctx) bool {
	marker := c.Locals(localsKeyWasPutToCache)
	return marker != nil
}
// New builds the idempotency middleware handler from the optional config.
//
// For each request carrying a non-empty cfg.KeyHeader value, the handler:
//  1. validates the key with cfg.KeyHeaderValidate;
//  2. replays the response stored under that key in cfg.Storage, if any;
//  3. otherwise takes cfg.Lock for the key, checks the storage once more,
//     runs the downstream handlers and stores the resulting status code,
//     body and (optionally filtered) headers under the key for cfg.Lifetime.
//
// Requests for which cfg.Next returns true, or that carry no key, are handed
// straight to the next handler. Nothing is stored when the downstream handler
// returns an error.
func New(config ...Config) fiber.Handler {
	// Resolve the effective configuration.
	cfg := configDefault(config...)

	// Lower-cased set of the header names that may be persisted.
	keepSet := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
	for _, name := range cfg.KeepResponseHeaders {
		keepSet[strings.ToLower(name)] = struct{}{}
	}

	// replayStored writes the response stored under key, if there is one, to c.
	// The boolean result tells the caller whether a stored entry was found.
	replayStored := func(c fiber.Ctx, key string) (bool, error) {
		raw, err := cfg.Storage.Get(key)
		if err != nil {
			return false, fmt.Errorf("failed to read response: %w", err)
		}
		if raw == nil {
			// Nothing stored for this key.
			return false, nil
		}

		var stored response
		if _, err := stored.UnmarshalMsg(raw); err != nil {
			return false, fmt.Errorf("failed to unmarshal response: %w", err)
		}

		_ = c.Status(stored.StatusCode)

		for name, values := range stored.Headers {
			for _, v := range values {
				c.RequestCtx().Response.Header.Add(name, v)
			}
		}

		if len(stored.Body) != 0 {
			if err := c.Send(stored.Body); err != nil {
				return true, err
			}
		}

		_ = c.Locals(localsKeyIsFromCache, true)

		return true, nil
	}

	return func(c fiber.Ctx) error {
		// Skip the middleware entirely when Next says so.
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Requests without an idempotency key are not handled here.
		key := utils.CopyString(c.Get(cfg.KeyHeader))
		if key == "" {
			return c.Next()
		}

		// Reject keys that do not pass validation.
		if err := cfg.KeyHeaderValidate(key); err != nil {
			return err
		}

		// Fast path: look for a stored response before taking the lock.
		replayed, err := replayStored(c, key)
		if err != nil {
			return fmt.Errorf("failed to write cached response at fastpath: %w", err)
		}
		if replayed {
			return nil
		}

		if err := cfg.Lock.Lock(key); err != nil {
			return fmt.Errorf("failed to lock: %w", err)
		}
		defer func() {
			// An unlock failure is only logged; it does not change the result.
			if err := cfg.Lock.Unlock(key); err != nil {
				log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", key, err)
			}
		}()

		// Second look now that the lock is held: a response may have been
		// stored between the fast path and the lock acquisition.
		replayed, err = replayStored(c, key)
		if err != nil {
			return fmt.Errorf("failed to write cached response while locked: %w", err)
		}
		if replayed {
			return nil
		}

		// Run the downstream handlers; on error, return it without storing.
		if err := c.Next(); err != nil {
			return err
		}

		// Snapshot the response that was just produced.
		snapshot := &response{
			StatusCode: c.Response().StatusCode(),
			Body:       utils.CopyBytes(c.Response().Body()),
		}

		allHeaders := make(map[string][]string)
		if err := c.Bind().RespHeader(allHeaders); err != nil {
			return fmt.Errorf("failed to bind to response headers: %w", err)
		}
		if cfg.KeepResponseHeaders == nil {
			// No filter configured: keep every header.
			snapshot.Headers = allHeaders
		} else {
			// Keep only the headers whose lower-cased name is in the set.
			kept := make(map[string][]string)
			for name, values := range allHeaders {
				if _, wanted := keepSet[utils.ToLower(name)]; wanted {
					kept[name] = values
				}
			}
			snapshot.Headers = kept
		}

		// Serialize the snapshot.
		encoded, err := snapshot.MarshalMsg(nil)
		if err != nil {
			return fmt.Errorf("failed to marshal response: %w", err)
		}

		// Persist it under the idempotency key.
		if err := cfg.Storage.Set(key, encoded, cfg.Lifetime); err != nil {
			return fmt.Errorf("failed to save response: %w", err)
		}

		_ = c.Locals(localsKeyWasPutToCache, true)

		return nil
	}
}