diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go index 391f0b79..b39cb3ee 100644 --- a/internal/ctxwatch/context_watcher.go +++ b/internal/ctxwatch/context_watcher.go @@ -2,6 +2,7 @@ package ctxwatch import ( "context" + "sync" ) // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a @@ -10,8 +11,10 @@ type ContextWatcher struct { onCancel func() onUnwatchAfterCancel func() unwatchChan chan struct{} - watchInProgress bool - onCancelWasCalled bool + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. @@ -29,6 +32,9 @@ func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWat // Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { panic("Watch already in progress") } @@ -54,6 +60,9 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { // Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was // called then onUnwatchAfterCancel will also be called. func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 6348b729..289606c3 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,7 +59,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } -func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) cw.Unwatch() // unwatch when not / never watching @@ -70,6 +70,19 @@ func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { cw.Unwatch() // double unwatch } +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64