From b7106b725097932e9ae2eb73dd641ce7fbb80317 Mon Sep 17 00:00:00 2001 From: Mathieu Fenniak Date: Sat, 15 Nov 2025 18:57:34 -0700 Subject: [PATCH] feat: add MutexMap to caching module --- modules/cache/mutex_map.go | 60 +++++++++++++++ modules/cache/mutex_map_test.go | 131 ++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 modules/cache/mutex_map.go create mode 100644 modules/cache/mutex_map_test.go diff --git a/modules/cache/mutex_map.go b/modules/cache/mutex_map.go new file mode 100644 index 0000000000..beb27f32c8 --- /dev/null +++ b/modules/cache/mutex_map.go @@ -0,0 +1,60 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package cache + +import ( + "sync" +) + +// MutexMap is basically a map[string]sync.Mutex which allows you to have one mutex per string key being locked. Unlike +// a map[string]sync.Mutex, this map will automatically remove the Mutexes from itself when they are not being waited +// for, preventing resource waste. It does this by keeping a reference count of the current Lock calls for the given +// key. +type MutexMap struct { + mu sync.Mutex // mutex to be held when accessing mutexMap + mutexMap map[string]*refcountMutex +} + +type refcountMutex struct { + refCount int // access to refCount is protected by the MutexMap's mu + + sync.Mutex +} + +// Locks the given key, and returns a function that must be invoked to unlock the key. +func (m *MutexMap) Lock(key string) func() { + m.mu.Lock() + if m.mutexMap == nil { + m.mutexMap = make(map[string]*refcountMutex) + } + mutex, ok := m.mutexMap[key] + if !ok { + mutex = &refcountMutex{} + m.mutexMap[key] = mutex + } + mutex.refCount++ + m.mu.Unlock() + + mutex.Lock() + + unlockPending := true + + return func() { + if !unlockPending { + // unlocking twice would cause incorrect reference counts and might release another goroutine's mutex -- try + // to detect and panic so that this programming error can be found closest to the source. + panic("MutexMap unlock invoked twice") + } + + unlockPending = false + mutex.Unlock() + + m.mu.Lock() + mutex.refCount-- + if mutex.refCount == 0 { + delete(m.mutexMap, key) + } + m.mu.Unlock() + } +} diff --git a/modules/cache/mutex_map_test.go b/modules/cache/mutex_map_test.go new file mode 100644 index 0000000000..324b3228d8 --- /dev/null +++ b/modules/cache/mutex_map_test.go @@ -0,0 +1,131 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package cache + +import ( + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMutexMap_BasicLockUnlock(t *testing.T) { + mm := &MutexMap{} + + unlock := mm.Lock("test-key") + unlock() + + // Should be able to lock again + unlock2 := mm.Lock("test-key") + unlock2() +} + +func TestMutexMap_ConcurrentSameKey(t *testing.T) { + mm := &MutexMap{} + var anotherLockActive atomic.Bool + var firstError atomic.Value + var wg sync.WaitGroup + + for range 10 { + wg.Go(func() { + unlock := mm.Lock("shared-key") + defer unlock() + + // should *not* find that another goroutine has put `true` into here. + swapped := anotherLockActive.CompareAndSwap(false, true) + if !swapped { + firstError.CompareAndSwap(nil, "anotherLockActive was true!") + } + time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution + anotherLockActive.Store(false) + }) + } + + wg.Wait() + + if err := firstError.Load(); err != nil { + t.Fatal(err) + } +} + +func TestMutexMap_DifferentKeys(t *testing.T) { + mm := &MutexMap{} + done := make(chan bool, 1) + + go func() { + // If these somehow refered to the same underlying `sync.Mutex`, because `sync.Mutex` is not re-entrant this would + // never complete. + unlock1 := mm.Lock("test-key-1") + unlock2 := mm.Lock("test-key-2") + unlock3 := mm.Lock("test-key-3") + unlock1() + unlock2() + unlock3() + done <- true + }() + + select { + case <-done: + // Success + case <-time.After(1 * time.Second): // early timeout so that we don't wait for t.Deadline() + t.Fatal("test incomplete after timeout, indicating a locking bug") + } +} + +func TestMutexMap_SimpleCleanup(t *testing.T) { + mm := &MutexMap{} + unlock1 := mm.Lock("test-key-1") + + mm.mu.Lock() + assert.Len(t, mm.mutexMap, 1) + mm.mu.Unlock() + + unlock1() + + mm.mu.Lock() + assert.Empty(t, mm.mutexMap) + mm.mu.Unlock() +} + +func TestMutexMap_ConcurrentCleanup(t *testing.T) { + mm := &MutexMap{} + var foundRefGreaterThanOne atomic.Bool + var wg sync.WaitGroup + + for range 10 { + wg.Go(func() { + unlock := mm.Lock("shared-key") + defer unlock() + + time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) // jitter the goroutines to ensure no serial execution + + mm.mu.Lock() + rcMutex := mm.mutexMap["shared-key"] + if rcMutex.refCount > 1 { + foundRefGreaterThanOne.Store(true) + } + mm.mu.Unlock() + }) + } + + wg.Wait() + + assert.True(t, foundRefGreaterThanOne.Load(), "expected to find a refCount > 1") + + mm.mu.Lock() + assert.Empty(t, mm.mutexMap) + mm.mu.Unlock() +} + +func TestMutexMap_UnlockTwice(t *testing.T) { + mm := &MutexMap{} + assert.Panics(t, func() { + unlock := mm.Lock("test") + unlock() + unlock() + }) +}