diff --git a/pkg/scheduler/framework/plugins/nodevolumelimits/csi.go b/pkg/scheduler/framework/plugins/nodevolumelimits/csi.go index a61398b23c1..7ba415589ba 100644 --- a/pkg/scheduler/framework/plugins/nodevolumelimits/csi.go +++ b/pkg/scheduler/framework/plugins/nodevolumelimits/csi.go @@ -23,11 +23,12 @@ import ( v1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/rand" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" corelisters "k8s.io/client-go/listers/core/v1" storagelisters "k8s.io/client-go/listers/storage/v1" + "k8s.io/client-go/tools/cache" ephemeral "k8s.io/component-helpers/storage/ephemeral" storagehelpers "k8s.io/component-helpers/storage/volume" csitrans "k8s.io/csi-translation-lib" @@ -63,6 +64,7 @@ type CSILimits struct { scLister storagelisters.StorageClassLister vaLister storagelisters.VolumeAttachmentLister csiDriverLister storagelisters.CSIDriverLister + vaIndexer cache.Indexer randomVolumeIDPrefix string enableVolumeLimitScaling bool @@ -588,6 +590,14 @@ func (pl *CSILimits) getCSIDriverInfoFromSC(logger klog.Logger, csiNode *storage return provisioner, volumeHandle } +func volumeAttachmentIndexer(obj interface{}) ([]string, error) { + va, ok := obj.(*storagev1.VolumeAttachment) + if !ok { + return []string{}, nil + } + return []string{va.Spec.NodeName}, nil +} + // NewCSI initializes a new plugin and returns it. func NewCSI(_ context.Context, _ runtime.Object, handle fwk.Handle, fts feature.Features) (fwk.Plugin, error) { informerFactory := handle.SharedInformerFactory() @@ -596,6 +606,12 @@ func NewCSI(_ context.Context, _ runtime.Object, handle fwk.Handle, fts feature. scLister := informerFactory.Storage().V1().StorageClasses().Lister() vaLister := informerFactory.Storage().V1().VolumeAttachments().Lister() csiDriverLister := informerFactory.Storage().V1().CSIDrivers().Lister() + vaInformer := informerFactory.Storage().V1().VolumeAttachments().Informer() + if err := vaInformer.AddIndexers(cache.Indexers{vaIndexKey: volumeAttachmentIndexer}); err != nil { + if vaInformer.GetIndexer().GetIndexers()[vaIndexKey] == nil { + return nil, fmt.Errorf("failed to add index to VA informer: %w", err) + } + } csiTranslator := csitrans.New() return &CSILimits{ @@ -608,6 +624,7 @@ func NewCSI(_ context.Context, _ runtime.Object, handle fwk.Handle, fts feature. enableVolumeLimitScaling: fts.EnableVolumeLimitScaling, randomVolumeIDPrefix: rand.String(32), translator: csiTranslator, + vaIndexer: vaInformer.GetIndexer(), }, nil } @@ -627,14 +644,22 @@ func getVolumeLimits(csiNode *storagev1.CSINode) map[string]int64 { return nodeVolumeLimits } +const vaIndexKey = "va.spec.nodename" + // getNodeVolumeAttachmentInfo returns a map of volumeID to driver name for the given node. func (pl *CSILimits) getNodeVolumeAttachmentInfo(logger klog.Logger, nodeName string) (map[string]string, error) { volumeAttachments := make(map[string]string) - vas, err := pl.vaLister.List(labels.Everything()) + vas, err := pl.vaIndexer.ByIndex(vaIndexKey, nodeName) if err != nil { return nil, err } - for _, va := range vas { + for _, vao := range vas { + va, ok := vao.(*storagev1.VolumeAttachment) + if !ok { + utilruntime.HandleErrorWithLogger(logger, fmt.Errorf("unexpected object type in volume attachment indexer: %v", vao), + "volume indexer not available") + continue + } if va.Spec.NodeName == nodeName { if va.Spec.Attacher == "" { logger.V(5).Info("VolumeAttachment has no attacher", "VolumeAttachment", klog.KObj(va)) diff --git a/pkg/scheduler/framework/plugins/nodevolumelimits/csi_test.go b/pkg/scheduler/framework/plugins/nodevolumelimits/csi_test.go index 5ba3cd6c63a..d2c18915bfa 100644 --- a/pkg/scheduler/framework/plugins/nodevolumelimits/csi_test.go +++ b/pkg/scheduler/framework/plugins/nodevolumelimits/csi_test.go @@ -29,9 +29,13 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/rand" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/tools/cache" csitrans "k8s.io/csi-translation-lib" csilibplugins "k8s.io/csi-translation-lib/plugins" fwk "k8s.io/kube-scheduler/framework" @@ -641,16 +645,23 @@ func TestCSILimits(t *testing.T) { enableMigrationOnNode(csiNode, csilibplugins.AWSEBSInTreePluginName) } csiTranslator := csitrans.New() + fakecli := buildFakeClientWithVALister(test.vaCount, test.driverNames...) + informerFactory := informers.NewSharedInformerFactory(fakecli, 0) + if err := informerFactory.Storage().V1().VolumeAttachments().Informer().AddIndexers(cache.Indexers{vaIndexKey: volumeAttachmentIndexer}); err != nil { + t.Error(err) + } + _, ctx := ktesting.NewTestContext(t) + informerFactory.Start(ctx.Done()) + informerFactory.WaitForCacheSync(ctx.Done()) p := &CSILimits{ csiManager: NewCSIManager(getFakeCSINodeLister(csiNode)), pvLister: getFakeCSIPVLister(test.filterName, test.driverNames...), pvcLister: append(getFakeCSIPVCLister(test.filterName, scName, test.driverNames...), test.extraClaims...), scLister: getFakeCSIStorageClassLister(scName, test.driverNames[0]), - vaLister: getFakeVolumeAttachmentLister(test.vaCount, test.driverNames...), + vaIndexer: informerFactory.Storage().V1().VolumeAttachments().Informer().GetIndexer(), randomVolumeIDPrefix: rand.String(32), translator: csiTranslator, } - _, ctx := ktesting.NewTestContext(t) _, gotPreFilterStatus := p.PreFilter(ctx, nil, test.newPod, nil) if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus, statusCmpOpts...); diff != "" { t.Errorf("PreFilter status does not match (-want, +got):\n%s", diff) @@ -1074,12 +1085,12 @@ func TestCSILimitsAfterCSINodeUpdatedQHint(t *testing.T) { } } -func getFakeVolumeAttachmentLister(count int, driverNames ...string) tf.VolumeAttachmentLister { - vaLister := tf.VolumeAttachmentLister{} +func buildFakeClientWithVALister(count int, driverNames ...string) *fake.Clientset { + vas := []runtime.Object{} for _, driver := range driverNames { for j := 0; j < count; j++ { pvName := fmt.Sprintf("csi-%s-%d", driver, j) - va := storagev1.VolumeAttachment{ + va := &storagev1.VolumeAttachment{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("va-%s-%d", driver, j), }, @@ -1091,11 +1102,13 @@ func getFakeVolumeAttachmentLister(count int, driverNames ...string) tf.VolumeAt }, }, } - vaLister = append(vaLister, va) + vas = append(vas, va) } } - return vaLister + fakeCli := fake.NewClientset(vas...) + return fakeCli } + func getFakeCSIPVLister(volumeName string, driverNames ...string) tf.PersistentVolumeLister { pvLister := tf.PersistentVolumeLister{} for _, driver := range driverNames { @@ -1351,6 +1364,14 @@ func TestVolumeLimitScalingGate(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { node, csiNode := getNodeWithPodAndVolumeLimits(tt.limitSource, []*v1.Pod{}, tt.limit, ebsCSIDriverName) + fakecli := buildFakeClientWithVALister(0, ebsCSIDriverName) + informerFactory := informers.NewSharedInformerFactory(fakecli, 0) + if err := informerFactory.Storage().V1().VolumeAttachments().Informer().AddIndexers(cache.Indexers{vaIndexKey: volumeAttachmentIndexer}); err != nil { + t.Error(err) + } + _, ctx := ktesting.NewTestContext(t) + informerFactory.Start(ctx.Done()) + informerFactory.WaitForCacheSync(ctx.Done()) csiTranslator := csitrans.New() p := &CSILimits{ @@ -1358,7 +1379,8 @@ func TestVolumeLimitScalingGate(t *testing.T) { pvLister: getFakeCSIPVLister("csi", ebsCSIDriverName), pvcLister: getFakeCSIPVCLister("csi", scName, ebsCSIDriverName), scLister: getFakeCSIStorageClassLister(scName, ebsCSIDriverName), - vaLister: getFakeVolumeAttachmentLister(0, ebsCSIDriverName), + vaLister: informerFactory.Storage().V1().VolumeAttachments().Lister(), + vaIndexer: informerFactory.Storage().V1().VolumeAttachments().Informer().GetIndexer(), csiDriverLister: func() fakeCSIDriverLister { if tt.csiDriverPresent { if tt.preventPodSchedulingIfMissing { @@ -1373,7 +1395,6 @@ func TestVolumeLimitScalingGate(t *testing.T) { translator: csiTranslator, } - _, ctx := ktesting.NewTestContext(t) // Ensure PreFilter doesn't skip _, preStatus := p.PreFilter(ctx, nil, newPod, nil) if preStatus.Code() == fwk.Skip {