From 37980937b5143d8c5295a667295e3517d533c867 Mon Sep 17 00:00:00 2001
From: Enver Biševac
Date: Tue, 10 Dec 2024 00:28:43 +0000
Subject: [PATCH] feat: [code-2911] data race panic when cache counters are modified (#3130)

* fix data race panic when cache counters are modified
---
 cache/redis_cache.go | 13 ++++++-------
 cache/ttl_cache.go   | 11 ++++++-----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/cache/redis_cache.go b/cache/redis_cache.go
index d91914fdf..3c37f8dd5 100644
--- a/cache/redis_cache.go
+++ b/cache/redis_cache.go
@@ -18,6 +18,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync/atomic"
 	"time"
 
 	"github.com/go-redis/redis/v8"
@@ -31,8 +32,8 @@ type Redis[K any, V any] struct {
 	getter     Getter[K, V]
 	keyEncoder func(K) string
 	codec      Codec[V]
-	countHit   int64
-	countMiss  int64
+	countHit   atomic.Int64
+	countMiss  atomic.Int64
 	logErrFn   LogErrFn
 }
 
@@ -63,15 +64,13 @@ func NewRedis[K any, V any](
 		getter:     getter,
 		keyEncoder: keyEncoder,
 		codec:      codec,
-		countHit:   0,
-		countMiss:  0,
 		logErrFn:   logErrFn,
 	}
 }
 
 // Stats returns number of cache hits and misses and can be used to monitor the cache efficiency.
 func (c *Redis[K, V]) Stats() (int64, int64) {
-	return c.countHit, c.countMiss
+	return c.countHit.Load(), c.countMiss.Load()
 }
 
 // Get implements the cache.Cache interface.
@@ -84,14 +83,14 @@ func (c *Redis[K, V]) Get(ctx context.Context, key K) (V, error) {
 	if err == nil {
 		value, decErr := c.codec.Decode(raw)
 		if decErr == nil {
-			c.countHit++
+			c.countHit.Add(1)
 			return value, nil
 		}
 	} else if !errors.Is(err, redis.Nil) && c.logErrFn != nil {
 		c.logErrFn(ctx, err)
 	}
 
-	c.countMiss++
+	c.countMiss.Add(1)
 
 	item, err := c.getter.Find(ctx, key)
 	if err != nil {
diff --git a/cache/ttl_cache.go b/cache/ttl_cache.go
index 6c82f8102..3c8970f6b 100644
--- a/cache/ttl_cache.go
+++ b/cache/ttl_cache.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"sort"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/exp/constraints"
@@ -34,8 +35,8 @@ type TTLCache[K comparable, V any] struct {
 	purgeStop chan struct{}
 	getter    Getter[K, V]
 	maxAge    time.Duration
-	countHit  int64
-	countMiss int64
+	countHit  atomic.Int64
+	countMiss atomic.Int64
 }
 
 // ExtendedTTLCache is an extended version of the TTLCache.
@@ -113,7 +114,7 @@ func (c *TTLCache[K, V]) Stop() {
 
 // Stats returns number of cache hits and misses and can be used to monitor the cache efficiency.
 func (c *TTLCache[K, V]) Stats() (int64, int64) {
-	return c.countHit, c.countMiss
+	return c.countHit.Load(), c.countMiss.Load()
 }
 
 func (c *TTLCache[K, V]) fetch(key K, now time.Time) (V, bool) {
@@ -122,12 +123,12 @@ func (c *TTLCache[K, V]) fetch(key K, now time.Time) (V, bool) {
 
 	item, ok := c.cache[key]
 	if !ok || now.Sub(item.added) > c.maxAge {
-		c.countMiss++
+		c.countMiss.Add(1)
 		var nothing V
 		return nothing, false
 	}
 
-	c.countHit++
+	c.countHit.Add(1)
 
 	// we deliberately don't update the `item.added` timestamp for `now` because
 	// we want to cache the items only for a short period.
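
Note (not part of the patch): a minimal, self-contained Go sketch of the pattern this change adopts. The hitCounter type and main function below are hypothetical and exist only for illustration; only the countHit/countMiss field names and the atomic.Int64 Add/Load calls come from the diff above. Plain int64 increments from concurrent Get calls are flagged by the Go race detector; atomic.Int64 makes both the increments and the Stats() reads safe.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// hitCounter mirrors the hit/miss fields of Redis and TTLCache after this change
// (hypothetical type, for illustration only).
type hitCounter struct {
	countHit  atomic.Int64
	countMiss atomic.Int64
}

func main() {
	var c hitCounter
	var wg sync.WaitGroup

	// Before this patch, concurrent callers effectively did `c.countHit++` on a
	// plain int64, which `go run -race` reports as a data race. Add(1) is safe.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.countHit.Add(1)
		}()
	}
	wg.Wait()

	// Stats() now reads the counters with Load() instead of a plain field read.
	fmt.Println("hits:", c.countHit.Load(), "misses:", c.countMiss.Load())
}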