mirror of https://github.com/gofiber/fiber.git
258 lines
6.6 KiB
Go
258 lines
6.6 KiB
Go
// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
|
|
|
|
import (
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/gofiber/utils/v2"
|
|
)
|
|
|
|
// timestampUpdatePeriod controls how often the middleware's cached Unix
// timestamp is refreshed; that timestamp is what entry expirations are compared
// against. A longer period makes expirations less precise, while a shorter one
// wakes the updater more often than the second-level timestamp really needs.
const timestampUpdatePeriod time.Duration = 300 * time.Millisecond
|
|
|
|
// Values the middleware writes into the configured cache status header.
const (
	// cacheHit means the response was served from the cache.
	cacheHit = "hit"
	// cacheMiss means no usable cached record existed, so the response came
	// from the next handler.
	cacheMiss = "miss"
	// cacheUnreachable means the cache was bypassed or could not be used for
	// this request.
	cacheUnreachable = "unreachable"
)
|
|
|
|
// Request Cache-Control directives the middleware reacts to.
const (
	// noStore: the request must not be cached at all.
	noStore = "no-store"
	// noCache: the request must not be answered from the cache.
	noCache = "no-cache"
)
|
|
|
|
// ignoreHeaders lists response header names that are never copied into a
// cached entry's stored headers. It contains the hop-by-hop headers from
// RFC 2616 §13.5.1, which only describe the current connection, plus headers
// that the cache manager already stores in dedicated fields.
//
// Lookups use the exact key bytes produced by the response-header visitor, so
// where the RFC spelling and the normalized key form differ, both are listed.
var ignoreHeaders = map[string]any{
	"Connection":          nil,
	"Keep-Alive":          nil,
	"Proxy-Authenticate":  nil,
	"Proxy-Authorization": nil,
	"TE":                  nil,
	"Te":                  nil, // normalized header-key form of "TE"
	"Trailers":            nil, // spelling used in RFC 2616 §13.5.1 (kept for compatibility)
	"Trailer":             nil, // the header's actual name (RFC 7230 §4.4); "Trailers" never matches it
	"Transfer-Encoding":   nil,
	"Upgrade":             nil,
	"Content-Type":        nil, // already stored explicitly by the cache manager
	"Content-Encoding":    nil, // already stored explicitly by the cache manager
}
|
|
|
|
// New creates a new middleware handler
|
|
func New(config ...Config) fiber.Handler {
|
|
// Set default config
|
|
cfg := configDefault(config...)
|
|
|
|
// Nothing to cache
|
|
if int(cfg.Expiration.Seconds()) < 0 {
|
|
return func(c fiber.Ctx) error {
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
var (
|
|
// Cache settings
|
|
mux = &sync.RWMutex{}
|
|
timestamp = uint64(time.Now().Unix()) //nolint:gosec //Not a concern
|
|
)
|
|
// Create manager to simplify storage operations ( see manager.go )
|
|
manager := newManager(cfg.Storage)
|
|
// Create indexed heap for tracking expirations ( see heap.go )
|
|
heap := &indexedHeap{}
|
|
// count stored bytes (sizes of response bodies)
|
|
var storedBytes uint
|
|
|
|
// Update timestamp in the configured interval
|
|
go func() {
|
|
for {
|
|
atomic.StoreUint64(×tamp, uint64(time.Now().Unix())) //nolint:gosec //Not a concern
|
|
time.Sleep(timestampUpdatePeriod)
|
|
}
|
|
}()
|
|
|
|
// Delete key from both manager and storage
|
|
deleteKey := func(dkey string) {
|
|
manager.del(dkey)
|
|
// External storage saves body data with different key
|
|
if cfg.Storage != nil {
|
|
manager.del(dkey + "_body")
|
|
}
|
|
}
|
|
|
|
// Return new handler
|
|
return func(c fiber.Ctx) error {
|
|
// Refrain from caching
|
|
if hasRequestDirective(c, noStore) {
|
|
return c.Next()
|
|
}
|
|
|
|
requestMethod := c.Method()
|
|
|
|
// Only cache selected methods
|
|
if !slices.Contains(cfg.Methods, requestMethod) {
|
|
c.Set(cfg.CacheHeader, cacheUnreachable)
|
|
return c.Next()
|
|
}
|
|
|
|
// Get key from request
|
|
// TODO(allocation optimization): try to minimize the allocation from 2 to 1
|
|
key := cfg.KeyGenerator(c) + "_" + requestMethod
|
|
|
|
// Get entry from pool
|
|
e := manager.get(key)
|
|
|
|
// Lock entry
|
|
mux.Lock()
|
|
|
|
// Get timestamp
|
|
ts := atomic.LoadUint64(×tamp)
|
|
|
|
// Cache Entry found
|
|
if e != nil {
|
|
// Invalidate cache if requested
|
|
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) {
|
|
e.exp = ts - 1
|
|
}
|
|
|
|
// Check if entry is expired
|
|
if e.exp != 0 && ts >= e.exp {
|
|
deleteKey(key)
|
|
if cfg.MaxBytes > 0 {
|
|
_, size := heap.remove(e.heapidx)
|
|
storedBytes -= size
|
|
}
|
|
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
|
|
// Separate body value to avoid msgp serialization
|
|
// We can store raw bytes with Storage 👍
|
|
if cfg.Storage != nil {
|
|
e.body = manager.getRaw(key + "_body")
|
|
}
|
|
// Set response headers from cache
|
|
c.Response().SetBodyRaw(e.body)
|
|
c.Response().SetStatusCode(e.status)
|
|
c.Response().Header.SetContentTypeBytes(e.ctype)
|
|
if len(e.cencoding) > 0 {
|
|
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
|
|
}
|
|
for k, v := range e.headers {
|
|
c.Response().Header.SetBytesV(k, v)
|
|
}
|
|
// Set Cache-Control header if enabled
|
|
if cfg.CacheControl {
|
|
maxAge := strconv.FormatUint(e.exp-ts, 10)
|
|
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
|
}
|
|
|
|
c.Set(cfg.CacheHeader, cacheHit)
|
|
|
|
mux.Unlock()
|
|
|
|
// Return response
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// make sure we're not blocking concurrent requests - do unlock
|
|
mux.Unlock()
|
|
|
|
// Continue stack, return err to Fiber if exist
|
|
if err := c.Next(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// lock entry back and unlock on finish
|
|
mux.Lock()
|
|
defer mux.Unlock()
|
|
|
|
// Don't cache response if Next returns true
|
|
if cfg.Next != nil && cfg.Next(c) {
|
|
c.Set(cfg.CacheHeader, cacheUnreachable)
|
|
return nil
|
|
}
|
|
|
|
// Don't try to cache if body won't fit into cache
|
|
bodySize := uint(len(c.Response().Body()))
|
|
if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
|
|
c.Set(cfg.CacheHeader, cacheUnreachable)
|
|
return nil
|
|
}
|
|
|
|
// Remove oldest to make room for new
|
|
if cfg.MaxBytes > 0 {
|
|
for storedBytes+bodySize > cfg.MaxBytes {
|
|
key, size := heap.removeFirst()
|
|
deleteKey(key)
|
|
storedBytes -= size
|
|
}
|
|
}
|
|
|
|
e = manager.acquire()
|
|
// Cache response
|
|
e.body = utils.CopyBytes(c.Response().Body())
|
|
e.status = c.Response().StatusCode()
|
|
e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
|
|
e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
|
|
|
|
// Store all response headers
|
|
// (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
|
|
if cfg.StoreResponseHeaders {
|
|
e.headers = make(map[string][]byte)
|
|
c.Response().Header.VisitAll(
|
|
func(key, value []byte) {
|
|
// create real copy
|
|
keyS := string(key)
|
|
if _, ok := ignoreHeaders[keyS]; !ok {
|
|
e.headers[keyS] = utils.CopyBytes(value)
|
|
}
|
|
},
|
|
)
|
|
}
|
|
|
|
// default cache expiration
|
|
expiration := cfg.Expiration
|
|
// Calculate expiration by response header or other setting
|
|
if cfg.ExpirationGenerator != nil {
|
|
expiration = cfg.ExpirationGenerator(c, &cfg)
|
|
}
|
|
e.exp = ts + uint64(expiration.Seconds())
|
|
|
|
// Store entry in heap
|
|
if cfg.MaxBytes > 0 {
|
|
e.heapidx = heap.put(key, e.exp, bodySize)
|
|
storedBytes += bodySize
|
|
}
|
|
|
|
// For external Storage we store raw body separated
|
|
if cfg.Storage != nil {
|
|
manager.setRaw(key+"_body", e.body, expiration)
|
|
// avoid body msgp encoding
|
|
e.body = nil
|
|
manager.set(key, e, expiration)
|
|
manager.release(e)
|
|
} else {
|
|
// Store entry in memory
|
|
manager.set(key, e, expiration)
|
|
}
|
|
|
|
c.Set(cfg.CacheHeader, cacheMiss)
|
|
|
|
// Finish response
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Check if request has directive
|
|
func hasRequestDirective(c fiber.Ctx, directive string) bool {
|
|
return strings.Contains(c.Get(fiber.HeaderCacheControl), directive)
|
|
}
|