diff --git a/pkg/controller/volume/selinuxwarning/cache/volumecache.go b/pkg/controller/volume/selinuxwarning/cache/volumecache.go index 50dd0d60156..9613b5e698e 100644 --- a/pkg/controller/volume/selinuxwarning/cache/volumecache.go +++ b/pkg/controller/volume/selinuxwarning/cache/volumecache.go @@ -22,6 +22,7 @@ import ( "sync" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/controller/volume/selinuxwarning/translator" @@ -57,6 +58,9 @@ type volumeCache struct { seLinuxTranslator *translator.ControllerSELinuxTranslator // All volumes of all existing Pods. volumes map[v1.UniqueVolumeName]usedVolume + // Reverse index: maps each pod to the list of volumes it uses. + // The index is used during pod deletion. + podToVolumes map[cache.ObjectName]sets.Set[v1.UniqueVolumeName] } var _ VolumeCache = &volumeCache{} @@ -66,6 +70,7 @@ func NewVolumeLabelCache(seLinuxTranslator *translator.ControllerSELinuxTranslat return &volumeCache{ seLinuxTranslator: seLinuxTranslator, volumes: make(map[v1.UniqueVolumeName]usedVolume), + podToVolumes: make(map[cache.ObjectName]sets.Set[v1.UniqueVolumeName]), } } @@ -111,6 +116,9 @@ func (c *volumeCache) AddVolume(logger klog.Logger, volumeName v1.UniqueVolumeNa pods: newPodInfoListForPod(podKey, label, changePolicy), } c.volumes[volumeName] = volume + + // Add to reverse index + c.registerPodVolume(podKey, volumeName) return conflicts } @@ -129,6 +137,9 @@ func (c *volumeCache) AddVolume(logger klog.Logger, volumeName v1.UniqueVolumeNa // Add the updated pod info to the cache volume.pods[podKey] = podInfo + // Add to reverse index + c.registerPodVolume(podKey, volumeName) + // Emit conflicts for the pod for otherPodKey, otherPodInfo := range volume.pods { if otherPodInfo.changePolicy != changePolicy { @@ -177,12 +188,28 @@ func (c *volumeCache) DeletePod(logger klog.Logger, podKey cache.ObjectName) { defer c.mutex.Unlock() defer c.dump(logger) - for volumeName, volume := range c.volumes { + // Use reverse index to only iterate through volumes this pod actually uses. + for volumeName := range c.podToVolumes[podKey] { + volume, found := c.volumes[volumeName] + if !found { + continue + } delete(volume.pods, podKey) if len(volume.pods) == 0 { delete(c.volumes, volumeName) } } + delete(c.podToVolumes, podKey) +} + +// registerPodVolume adds volumeName to the pod volume index. +// Make sure to hold c.mutex when calling this function. +func (c *volumeCache) registerPodVolume(podKey cache.ObjectName, volumeName v1.UniqueVolumeName) { + if podVolumes, ok := c.podToVolumes[podKey]; ok { + podVolumes.Insert(volumeName) + } else { + c.podToVolumes[podKey] = sets.New(volumeName) + } } func (c *volumeCache) dump(logger klog.Logger) { @@ -214,6 +241,22 @@ func (c *volumeCache) dump(logger klog.Logger) { logger.Info(" pod", "pod", podKey, "seLinuxLabel", podInfo.seLinuxLabel, "changePolicy", podInfo.changePolicy) } } + + // Collect all pods, sort them and print the associated volumes. + podKeys := make([]cache.ObjectName, 0, len(c.podToVolumes)) + for podKey := range c.podToVolumes { + podKeys = append(podKeys, podKey) + } + sort.Slice(podKeys, func(i, j int) bool { + return podKeys[i].String() < podKeys[j].String() + }) + + logger.Info("VolumeCache reverse index dump:") + for _, podKey := range podKeys { + podVolumes := sets.List(c.podToVolumes[podKey]) + slices.Sort(podVolumes) + logger.Info(" pod", "pod", podKey, "volumes", podVolumes) + } } // GetPodsForCSIDriver returns all pods that use volumes with the given CSI driver. diff --git a/pkg/controller/volume/selinuxwarning/cache/volumecache_test.go b/pkg/controller/volume/selinuxwarning/cache/volumecache_test.go index 5bba301b692..66885e519bf 100644 --- a/pkg/controller/volume/selinuxwarning/cache/volumecache_test.go +++ b/pkg/controller/volume/selinuxwarning/cache/volumecache_test.go @@ -45,6 +45,39 @@ func sortConflicts(conflicts []Conflict) { }) } +// verifyReverseIndexConsistency checks that forward and reverse indexes are symmetric +func verifyReverseIndexConsistency(t *testing.T, c *volumeCache) { + t.Helper() + + // For every (pod, volume) in reverse index, verify it exists in forward index. + for podKey, volumes := range c.podToVolumes { + for volumeName := range volumes { + volume, found := c.volumes[volumeName] + if !found { + t.Errorf("Reverse index has pod %s -> volume %s, but volume not in forward index", podKey, volumeName) + continue + } + if _, found := volume.pods[podKey]; !found { + t.Errorf("Reverse index has pod %s -> volume %s, but pod not in volume's pod list", podKey, volumeName) + } + } + } + + // For every (volume, pod) in forward index, verify it exists in reverse index. + for volumeName, volume := range c.volumes { + for podKey := range volume.pods { + podVolumes, found := c.podToVolumes[podKey] + if !found { + t.Errorf("Forward index has volume %s -> pod %s, but pod not in reverse index", volumeName, podKey) + continue + } + if _, found := podVolumes[volumeName]; !found { + t.Errorf("Forward index has volume %s -> pod %s, but volume not in pod's volume list", volumeName, podKey) + } + } + } +} + // Delete all items in a bigger cache and check it's empty func TestVolumeCache_DeleteAll(t *testing.T) { var podsToDelete []cache.ObjectName @@ -69,6 +102,8 @@ func TestVolumeCache_DeleteAll(t *testing.T) { t.Log("Before deleting all pods:") c.dump(dumpLogger) + verifyReverseIndexConsistency(t, c) + // Act: delete all pods for _, podKey := range podsToDelete { c.DeletePod(logger, podKey) @@ -79,6 +114,12 @@ func TestVolumeCache_DeleteAll(t *testing.T) { t.Errorf("Expected cache to be empty, got %d volumes", len(c.volumes)) c.dump(dumpLogger) } + + // Assert: the reverse index is also empty + if len(c.podToVolumes) != 0 { + t.Errorf("Expected reverse index to be empty, got %d pods", len(c.podToVolumes)) + } + verifyReverseIndexConsistency(t, c) } type podWithVolume struct { @@ -442,6 +483,9 @@ func TestVolumeCache_AddVolumeSendConflicts(t *testing.T) { t.Errorf("pod %s has unexpected info: %+v", podKey, existingInfo) } + // Verify reverse index consistency + verifyReverseIndexConsistency(t, c) + // Act again: get the conflicts via SendConflicts ch := make(chan Conflict) go func() {