From 4a326b0196eb44826e8aca038b528ea49e4293e4 Mon Sep 17 00:00:00 2001 From: Maciej Wyrzuc Date: Fri, 28 Nov 2025 12:49:19 +0000 Subject: [PATCH] Preempt pods in prebind phase without delete calls. This change allows preemption to preempt a pod that is not yet bound but is already in the prebind phase, without issuing a delete call to the apiserver. Pods are added to a special map of pods currently in the prebind phase, and preemption can cancel the context used for a given pod's prebind phase, allowing the error to be handled gracefully in the same manner as errors coming from prebind plugins. This results in pods being pushed to the backoff queue, allowing them to be rescheduled in upcoming scheduling cycles. --- .../default_preemption_test.go | 2 + .../framework/preemption/executor.go | 13 +- .../framework/preemption/executor_test.go | 128 +++++++++++ pkg/scheduler/framework/runtime/framework.go | 32 +++ .../framework/runtime/pods_in_prebind_map.go | 98 +++++++++ .../runtime/pods_in_prebind_map_test.go | 121 +++++++++++ pkg/scheduler/schedule_one.go | 18 +- pkg/scheduler/schedule_one_podgroup_test.go | 2 + pkg/scheduler/schedule_one_test.go | 58 +++++ pkg/scheduler/scheduler.go | 4 + .../kube-scheduler/framework/interface.go | 22 ++ .../scheduler/preemption/preemption_test.go | 200 ++++++++++++++++++ 12 files changed, 696 insertions(+), 2 deletions(-) create mode 100644 pkg/scheduler/framework/runtime/pods_in_prebind_map.go create mode 100644 pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go diff --git a/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go b/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go index 41592e22717..27de7964105 100644 --- a/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go +++ b/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go @@ -477,6 +477,7 @@ func TestPostFilter(t *testing.T) { frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(tt.pods, tt.nodes)), frameworkruntime.WithLogger(logger), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), ) if err != nil { t.Fatal(err) } @@ -2247,6 +2248,7 @@ func TestPreempt(t *testing.T) { frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(testPods, nodes)), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithLogger(logger), frameworkruntime.WithPodActivator(&fakePodActivator{}), ) diff --git a/pkg/scheduler/framework/preemption/executor.go b/pkg/scheduler/framework/preemption/executor.go index 0553c11b5df..739e9d7e299 100644 --- a/pkg/scheduler/framework/preemption/executor.go +++ b/pkg/scheduler/framework/preemption/executor.go @@ -77,6 +77,7 @@ func newExecutor(fh fwk.Handle) *Executor { logger := klog.FromContext(ctx) skipAPICall := false + eventMessage := fmt.Sprintf("Preempted by pod %v on node %v", preemptor.UID, c.Name()) // If the victim is a WaitingPod, try to preempt it without a delete call (victim will go back to backoff queue). // Otherwise we should delete the victim. 
if waitingPod := e.fh.GetWaitingPod(victim.UID); waitingPod != nil { @@ -84,6 +85,14 @@ func newExecutor(fh fwk.Handle) *Executor { logger.V(2).Info("Preemptor pod preempted a waiting pod", "preemptor", klog.KObj(preemptor), "waitingPod", klog.KObj(victim), "node", c.Name()) skipAPICall = true } + } else if podInPreBind := e.fh.GetPodInPreBind(victim.UID); podInPreBind != nil { + // If the victim is in the preBind cancel the binding process. + if podInPreBind.CancelPod(fmt.Sprintf("preempted by %s", pluginName)) { + logger.V(2).Info("Preemptor pod rejected a pod in preBind", "preemptor", klog.KObj(preemptor), "podInPreBind", klog.KObj(victim), "node", c.Name()) + skipAPICall = true + } else { + logger.V(5).Info("Failed to reject a pod in preBind, falling back to deletion via api call", "preemptor", klog.KObj(preemptor), "podInPreBind", klog.KObj(victim), "node", c.Name()) + } } if !skipAPICall { condition := &v1.PodCondition{ @@ -114,9 +123,11 @@ func newExecutor(fh fwk.Handle) *Executor { return nil } logger.V(2).Info("Preemptor Pod preempted victim Pod", "preemptor", klog.KObj(preemptor), "victim", klog.KObj(victim), "node", c.Name()) + } else { + eventMessage += " (in kube-scheduler memory)." } - fh.EventRecorder().Eventf(victim, preemptor, v1.EventTypeNormal, "Preempted", "Preempting", "Preempted by pod %v on node %v", preemptor.UID, c.Name()) + fh.EventRecorder().Eventf(victim, preemptor, v1.EventTypeNormal, "Preempted", "Preempting", eventMessage) return nil } diff --git a/pkg/scheduler/framework/preemption/executor_test.go b/pkg/scheduler/framework/preemption/executor_test.go index 908a92cff8e..0efef7c9274 100644 --- a/pkg/scheduler/framework/preemption/executor_test.go +++ b/pkg/scheduler/framework/preemption/executor_test.go @@ -47,6 +47,7 @@ import ( apidispatcher "k8s.io/kubernetes/pkg/scheduler/backend/api_dispatcher" internalcache "k8s.io/kubernetes/pkg/scheduler/backend/cache" internalqueue "k8s.io/kubernetes/pkg/scheduler/backend/queue" + "k8s.io/kubernetes/pkg/scheduler/framework" apicalls "k8s.io/kubernetes/pkg/scheduler/framework/api_calls" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/defaultbinder" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/queuesort" @@ -589,6 +590,7 @@ func TestPrepareCandidate(t *testing.T) { frameworkruntime.WithLogger(logger), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(tt.testPods, nodes)), frameworkruntime.WithPodNominator(nominator), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), @@ -817,6 +819,7 @@ func TestPrepareCandidateAsyncSetsPreemptingSets(t *testing.T) { frameworkruntime.WithLogger(logger), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(testPods, nodes)), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), frameworkruntime.WithPodNominator(internalqueue.NewSchedulingQueue(nil, informerFactory)), @@ -1054,6 +1057,7 @@ func TestAsyncPreemptionFailure(t *testing.T) { frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), 
frameworkruntime.WithPodActivator(fakeActivator), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(snapshotPods, []*v1.Node{st.MakeNode().Name(node1Name).Obj()})), ) if err != nil { @@ -1191,3 +1195,127 @@ func TestRemoveNominatedNodeName(t *testing.T) { } } } + +func TestPreemptPod(t *testing.T) { + preemptorPod := st.MakePod().Name("p").UID("p").Priority(highPriority).Obj() + victimPod := st.MakePod().Name("v").UID("v").Priority(midPriority).Obj() + + tests := []struct { + name string + addVictimToPrebind bool + addVictimToWaiting bool + expectCancel bool + expectedActions []string + }{ + { + name: "victim is in preBind, context should be cancelled", + addVictimToPrebind: true, + addVictimToWaiting: false, + expectCancel: true, + expectedActions: []string{}, + }, + { + name: "victim is in waiting pods, it should be rejected (no calls to apiserver)", + addVictimToPrebind: false, + addVictimToWaiting: true, + expectCancel: false, + expectedActions: []string{}, + }, + { + name: "victim is not in waiting/preBind pods, pod should be deleted", + addVictimToPrebind: false, + addVictimToWaiting: false, + expectCancel: false, + expectedActions: []string{"patch", "delete"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podsInPreBind := frameworkruntime.NewPodsInPreBindMap() + waitingPods := frameworkruntime.NewWaitingPodsMap() + registeredPlugins := append([]tf.RegisterPluginFunc{ + tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New)}, + tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New), + tf.RegisterPermitPlugin(waitingPermitPluginName, newWaitingPermitPlugin), + ) + objs := []runtime.Object{preemptorPod, victimPod} + cs := clientsetfake.NewClientset(objs...) 
+ informerFactory := informers.NewSharedInformerFactory(cs, 0) + eventBroadcaster := events.NewBroadcaster(&events.EventSinkImpl{Interface: cs.EventsV1()}) + logger, ctx := ktesting.NewTestContext(t) + + fwk, err := tf.NewFramework( + ctx, + registeredPlugins, "", + frameworkruntime.WithClientSet(cs), + frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot([]*v1.Pod{}, []*v1.Node{})), + frameworkruntime.WithInformerFactory(informerFactory), + frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podsInPreBind), + frameworkruntime.WithLogger(logger), + frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), + ) + if err != nil { + t.Fatal(err) + } + var victimCtx context.Context + var cancel context.CancelCauseFunc + if tt.addVictimToPrebind { + victimCtx, cancel = context.WithCancelCause(context.Background()) + fwk.AddPodInPreBind(victimPod.UID, cancel) + } + if tt.addVictimToWaiting { + status := fwk.RunPermitPlugins(ctx, framework.NewCycleState(), victimPod, "fake-node") + if !status.IsWait() { + t.Fatalf("Failed to add a pod to waiting list") + } + } + pe := NewEvaluator("FakePreemptionScorePostFilter", fwk, &FakePreemptionScorePostFilterPlugin{}, false) + + err = pe.PreemptPod(ctx, &candidate{}, preemptorPod, victimPod, "test-plugin") + if err != nil { + t.Fatal(err) + } + if tt.expectCancel { + if victimCtx.Err() == nil { + t.Errorf("Context of a binding pod should be cancelled") + } + } else { + if victimCtx != nil && victimCtx.Err() != nil { + t.Errorf("Context of a normal pod should not be cancelled") + } + } + + // check if the API call was made + actions := cs.Actions() + if len(actions) != len(tt.expectedActions) { + t.Errorf("Expected %d actions, but got %d", len(tt.expectedActions), len(actions)) + } + for i, action := range actions { + if action.GetVerb() != tt.expectedActions[i] { + t.Errorf("Expected action %s, but got %s", tt.expectedActions[i], action.GetVerb()) + } + } + }) + } +} + +// waitingPermitPlugin is a PermitPlugin that always returns Wait. 
+type waitingPermitPlugin struct{} + +var _ fwk.PermitPlugin = &waitingPermitPlugin{} + +const waitingPermitPluginName = "waitingPermitPlugin" + +func newWaitingPermitPlugin(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) { + return &waitingPermitPlugin{}, nil +} + +func (pl *waitingPermitPlugin) Name() string { + return waitingPermitPluginName +} + +func (pl *waitingPermitPlugin) Permit(ctx context.Context, _ fwk.CycleState, _ *v1.Pod, nodeName string) (*fwk.Status, time.Duration) { + return fwk.NewStatus(fwk.Wait, ""), 10 * time.Second +} diff --git a/pkg/scheduler/framework/runtime/framework.go b/pkg/scheduler/framework/runtime/framework.go index 88f659389dd..fdf9ea7e8bf 100644 --- a/pkg/scheduler/framework/runtime/framework.go +++ b/pkg/scheduler/framework/runtime/framework.go @@ -58,6 +58,7 @@ type frameworkImpl struct { registry Registry snapshotSharedLister fwk.SharedLister waitingPods *waitingPodsMap + podsInPreBind *podsInPreBindMap scorePluginWeight map[string]int preEnqueuePlugins []fwk.PreEnqueuePlugin enqueueExtensions []fwk.EnqueueExtensions @@ -153,6 +154,7 @@ type frameworkOptions struct { captureProfile CaptureProfile parallelizer parallelize.Parallelizer waitingPods *waitingPodsMap + podsInPreBind *podsInPreBindMap apiDispatcher *apidispatcher.APIDispatcher workloadManager fwk.WorkloadManager logger *klog.Logger @@ -285,6 +287,13 @@ func WithWaitingPods(wp *waitingPodsMap) Option { } } +// WithPodsInPreBind sets podsInPreBind for the scheduling frameworkImpl. +func WithPodsInPreBind(bp *podsInPreBindMap) Option { + return func(o *frameworkOptions) { + o.podsInPreBind = bp + } +} + // WithLogger overrides the default logger from k8s.io/klog. func WithLogger(logger klog.Logger) Option { return func(o *frameworkOptions) { @@ -323,6 +332,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler sharedCSIManager: options.sharedCSIManager, scorePluginWeight: make(map[string]int), waitingPods: options.waitingPods, + podsInPreBind: options.podsInPreBind, clientSet: options.clientSet, kubeConfig: options.kubeConfig, eventRecorder: options.eventRecorder, @@ -1456,6 +1466,10 @@ func (f *frameworkImpl) RunPreBindPlugins(ctx context.Context, state fwk.CycleSt return plStatus } err := plStatus.AsError() + if errors.Is(err, context.Canceled) { + err = context.Cause(ctx) + } + logger.Error(err, "Plugin failed", "plugin", pl.Name(), "pod", klog.KObj(pod), "node", nodeName) return fwk.AsStatus(fmt.Errorf("running PreBind plugin %q: %w", pl.Name(), err)) } @@ -1894,6 +1908,24 @@ func (f *frameworkImpl) RejectWaitingPod(uid types.UID) bool { return false } +// AddPodInPreBind adds a pod to the pods in preBind list. +func (f *frameworkImpl) AddPodInPreBind(uid types.UID, cancel context.CancelCauseFunc) { + f.podsInPreBind.add(uid, cancel) +} + +// GetPodInPreBind returns a pod that is in the binding cycle but before it is bound given its UID. +func (f *frameworkImpl) GetPodInPreBind(uid types.UID) fwk.PodInPreBind { + if bp := f.podsInPreBind.get(uid); bp != nil { + return bp + } + return nil +} + +// RemovePodInPreBind removes a pod from the pods in preBind list. +func (f *frameworkImpl) RemovePodInPreBind(uid types.UID) { + f.podsInPreBind.remove(uid) +} + // HasFilterPlugins returns true if at least one filter plugin is defined. 
func (f *frameworkImpl) HasFilterPlugins() bool { return len(f.filterPlugins) > 0 diff --git a/pkg/scheduler/framework/runtime/pods_in_prebind_map.go b/pkg/scheduler/framework/runtime/pods_in_prebind_map.go new file mode 100644 index 00000000000..41f369e51da --- /dev/null +++ b/pkg/scheduler/framework/runtime/pods_in_prebind_map.go @@ -0,0 +1,98 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtime + +import ( + "context" + "errors" + "sync" + + "k8s.io/apimachinery/pkg/types" + fwk "k8s.io/kube-scheduler/framework" +) + +// podsInPreBindMap is a thread-safe map of pods currently in the preBind phase. +type podsInPreBindMap struct { + pods map[types.UID]*podInPreBind + mu sync.RWMutex +} + +// NewPodsInPreBindMap creates an empty podsInPreBindMap. +func NewPodsInPreBindMap() *podsInPreBindMap { + return &podsInPreBindMap{ + pods: make(map[types.UID]*podInPreBind), + } +} + +// get returns a pod from the map if it exists. +func (pbm *podsInPreBindMap) get(uid types.UID) *podInPreBind { + pbm.mu.RLock() + defer pbm.mu.RUnlock() + return pbm.pods[uid] +} + +// add adds a pod to map, overwriting existing one. +func (pbm *podsInPreBindMap) add(uid types.UID, cancel context.CancelCauseFunc) { + pbm.mu.Lock() + defer pbm.mu.Unlock() + pbm.pods[uid] = &podInPreBind{cancel: cancel} +} + +// remove removes a pod from the map. +func (pbm *podsInPreBindMap) remove(uid types.UID) { + pbm.mu.Lock() + defer pbm.mu.Unlock() + delete(pbm.pods, uid) +} + +var _ fwk.PodInPreBind = &podInPreBind{} + +// podInPreBind describes a pod in the preBind phase, before the bind was called for a pod. +type podInPreBind struct { + finished bool + canceled bool + cancel context.CancelCauseFunc + mu sync.Mutex +} + +// CancelPod cancels the context running the preBind phase +// for a given pod. +func (bp *podInPreBind) CancelPod(message string) bool { + bp.mu.Lock() + defer bp.mu.Unlock() + if bp.finished { + return false + } + if !bp.canceled { + bp.cancel(errors.New(message)) + } + bp.canceled = true + return true +} + +// MarkPrebound marks the pod as finished with preBind phase +// of binding cycle, making it impossible to cancel +// the binding cycle for it. +func (bp *podInPreBind) MarkPrebound() bool { + bp.mu.Lock() + defer bp.mu.Unlock() + if bp.canceled { + return false + } + bp.finished = true + return true +} diff --git a/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go b/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go new file mode 100644 index 00000000000..add5800f983 --- /dev/null +++ b/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go @@ -0,0 +1,121 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtime + +import ( + "context" + "testing" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2/ktesting" +) + +const ( + testPodUID = types.UID("pod-1") +) + +func TestPodsInBindingMap(t *testing.T) { + tests := []struct { + name string + scenario func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) + }{ + { + name: "add and get", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + if pod == nil { + t.Fatalf("expected pod to be in map") + } + }, + }, + { + name: "add, remove and get", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + m.remove(testPodUID) + if m.get(testPodUID) != nil { + t.Errorf("expected pod to be removed from map") + } + }, + }, + { + name: "Verify CancelPod logic", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + // First cancel should succeed + if !pod.CancelPod("test cancel") { + t.Errorf("First CancelPod should return true") + } + if ctx.Err() == nil { + t.Errorf("Context should be cancelled") + } + + // Second cancel should also return true + if !pod.CancelPod("test cancel") { + t.Errorf("Second CancelPod should return true") + } + }, + }, + { + name: "Verify MarkBound logic", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + if !pod.MarkPrebound() { + t.Errorf("MarkBound should return true for fresh pod") + } + if !pod.finished { + t.Errorf("finished should be true") + } + + // Try to cancel after binding + if pod.CancelPod("test cancel") { + t.Errorf("CancelPod should return false after MarkBound") + } + if ctx.Err() != nil { + t.Errorf("Context should NOT be cancelled if MarkBound succeeded first") + } + }, + }, + { + name: "Verify MarkBound fails on cancelled pod", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + pod.CancelPod("test cancel") + if pod.MarkPrebound() { + t.Errorf("MarkBound should return false on cancelled pod") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := NewPodsInPreBindMap() + _, ctx := ktesting.NewTestContext(t) + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + tt.scenario(t, ctx, cancel, m) + }) + } +} diff --git a/pkg/scheduler/schedule_one.go b/pkg/scheduler/schedule_one.go index db5bba104fd..064e37b58b4 100644 --- a/pkg/scheduler/schedule_one.go +++ b/pkg/scheduler/schedule_one.go @@ -399,8 +399,9 @@ func (sched *Scheduler) bindingCycle( assumedPod := assumedPodInfo.Pod + var preFlightStatus *fwk.Status if sched.nominatedNodeNameForExpectationEnabled { - preFlightStatus := schedFramework.RunPreBindPreFlights(ctx, state, assumedPod, scheduleResult.SuggestedHost) + preFlightStatus = 
schedFramework.RunPreBindPreFlights(ctx, state, assumedPod, scheduleResult.SuggestedHost) if preFlightStatus.Code() == fwk.Error || // Unschedulable status is not supported in PreBindPreFlight and hence we regard it as an error. preFlightStatus.IsRejected() { @@ -444,11 +445,26 @@ func (sched *Scheduler) bindingCycle( // we can free the cluster events stored in the scheduling queue sooner, which is worth for busy clusters memory consumption wise. sched.SchedulingQueue.Done(assumedPod.UID) + // If we are going to run prebind plugins we put the pod in binding map to optimize preemption. + if preFlightStatus.IsSuccess() { + var podInPreBindCancel context.CancelCauseFunc + ctx, podInPreBindCancel = context.WithCancelCause(ctx) + defer podInPreBindCancel(nil) + defer schedFramework.RemovePodInPreBind(assumedPod.UID) + schedFramework.AddPodInPreBind(assumedPod.UID, podInPreBindCancel) + } // Run "prebind" plugins. if status := schedFramework.RunPreBindPlugins(ctx, state, assumedPod, scheduleResult.SuggestedHost); !status.IsSuccess() { return status } + // Verify that pod was not preempted during prebinding. + bindingPod := schedFramework.GetPodInPreBind(assumedPod.UID) + if bindingPod != nil && !bindingPod.MarkPrebound() { + err := context.Cause(ctx) + return fwk.AsStatus(err) + } + // Run "bind" plugins. if status := sched.bind(ctx, schedFramework, assumedPod, scheduleResult.SuggestedHost, state); !status.IsSuccess() { return status diff --git a/pkg/scheduler/schedule_one_podgroup_test.go b/pkg/scheduler/schedule_one_podgroup_test.go index c7853355ffa..02976c06634 100644 --- a/pkg/scheduler/schedule_one_podgroup_test.go +++ b/pkg/scheduler/schedule_one_podgroup_test.go @@ -1038,10 +1038,12 @@ func TestSubmitPodGroupAlgorithmResult(t *testing.T) { }, } waitingPods := frameworkruntime.NewWaitingPodsMap() + podInPreBind := frameworkruntime.NewPodsInPreBindMap() schedFwk, err := frameworkruntime.NewFramework(ctx, registry, &profileCfg, frameworkruntime.WithClientSet(client), frameworkruntime.WithEventRecorder(events.NewFakeRecorder(100)), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podInPreBind), ) if err != nil { t.Fatalf("Failed to create new framework: %v", err) diff --git a/pkg/scheduler/schedule_one_test.go b/pkg/scheduler/schedule_one_test.go index 88e79eef2a4..6dd19512f51 100644 --- a/pkg/scheduler/schedule_one_test.go +++ b/pkg/scheduler/schedule_one_test.go @@ -368,6 +368,30 @@ func newFakeNodeSelectorDependOnPodAnnotation(_ context.Context, _ runtime.Objec return &fakeNodeSelectorDependOnPodAnnotation{}, nil } +var blockingPreBindStartCh = make(chan struct{}, 1) + +var _ fwk.PreBindPlugin = &BlockingPreBindPlugin{} + +type BlockingPreBindPlugin struct { + handle fwk.Handle +} + +func (p *BlockingPreBindPlugin) Name() string { return "BlockingPreBindPlugin" } + +func (p *BlockingPreBindPlugin) PreBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { + go func() { blockingPreBindStartCh <- struct{}{} }() + <-ctx.Done() + return fwk.AsStatus(ctx.Err()) +} + +func (p *BlockingPreBindPlugin) PreBindPreFlight(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) (*fwk.PreBindPreFlightResult, *fwk.Status) { + return &fwk.PreBindPreFlightResult{}, nil +} + +func NewBlockingPreBindPlugin(_ context.Context, _ runtime.Object, h fwk.Handle) (fwk.Plugin, error) { + return &BlockingPreBindPlugin{handle: h}, nil +} + type TestPlugin struct { name string } @@ -718,6 +742,10 @@ func TestSchedulerScheduleOne(t 
*testing.T) { asyncAPICallsEnabled *bool // If nil, the test case is run with both true and false scheduleAsPodGroup *bool + // postSchedulingCycle is run after ScheduleOne function returns + // (synchronous part of scheduling is done and binding goroutine is launched) + // and before blocking execution on waiting for event with eventReason + postSchedulingCycle func(context.Context, *Scheduler) } table := []testItem{ { @@ -898,6 +926,29 @@ func TestSchedulerScheduleOne(t *testing.T) { expectError: fmt.Errorf(`running PreBind plugin "FakePreBind": %w`, preBindErr), eventReason: "FailedScheduling", }, + { + name: "prebind pod cancelled during prebind (external preemption)", + sendPod: testPod, + registerPluginFuncs: []tf.RegisterPluginFunc{ + tf.RegisterPreBindPlugin("BlockingPreBindPlugin", NewBlockingPreBindPlugin), + }, + mockScheduleResult: scheduleResultOk, + postSchedulingCycle: func(ctx context.Context, sched *Scheduler) { + <-blockingPreBindStartCh // Wait for plugin to start + // Trigger preemption from "outside" + if bp := sched.Profiles[testSchedulerName].GetPodInPreBind(testPod.UID); bp != nil { + bp.CancelPod("context cancelled externally") + } + }, + nominatedNodeNameForExpectationEnabled: ptr.To(true), + expectAssumedPod: assignedTestPod, + expectErrorPod: assignedTestPod, + expectForgetPod: assignedTestPod, + expectNominatedNodeName: testNode.Name, + expectPodInBackoffQ: testPod, + expectError: fmt.Errorf(`running PreBind plugin "BlockingPreBindPlugin": context cancelled externally`), + eventReason: "FailedScheduling", + }, { name: "binding failed", sendPod: testPod, @@ -1064,6 +1115,7 @@ func TestSchedulerScheduleOne(t *testing.T) { frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, testSchedulerName)), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWorkloadManager(wm), ) @@ -1115,6 +1167,10 @@ func TestSchedulerScheduleOne(t *testing.T) { sched.nodeInfoSnapshot = internalcache.NewEmptySnapshot() sched.ScheduleOne(ctx) + if item.postSchedulingCycle != nil { + item.postSchedulingCycle(ctx, sched) + } + if item.podToAdmit != nil { for { if waitingPod := sched.Profiles[testSchedulerName].GetWaitingPod(item.podToAdmit.pod); waitingPod != nil { @@ -1684,6 +1740,7 @@ func TestScheduleOneMarksPodAsProcessedBeforePreBind(t *testing.T) { frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, testSchedulerName)), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), ) if err != nil { t.Fatal(err) @@ -4562,6 +4619,7 @@ func setupTestScheduler(ctx context.Context, t *testing.T, client clientset.Inte frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithPodNominator(schedulingQueue), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(snapshot), ) if apiDispatcher != nil { diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 16c7da3802a..27c546c3e9b 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -318,6 +318,9 @@ func New(ctx context.Context, // waitingPods holds all the pods that 
are in the scheduler and waiting in the permit stage waitingPods := frameworkruntime.NewWaitingPodsMap() + // podsInPreBind holds all the pods that are in the scheduler in the preBind phase + podsInPreBind := frameworkruntime.NewPodsInPreBindMap() + var resourceClaimCache *assumecache.AssumeCache var resourceSliceTracker *resourceslicetracker.Tracker var draManager fwk.SharedDRAManager @@ -365,6 +368,7 @@ func New(ctx context.Context, frameworkruntime.WithExtenders(extenders), frameworkruntime.WithMetricsRecorder(metricsRecorder), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podsInPreBind), frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithSharedCSIManager(sharedCSIManager), frameworkruntime.WithWorkloadManager(workloadManager), diff --git a/staging/src/k8s.io/kube-scheduler/framework/interface.go b/staging/src/k8s.io/kube-scheduler/framework/interface.go index 4d76e4d7ab2..a3dd45d5af9 100644 --- a/staging/src/k8s.io/kube-scheduler/framework/interface.go +++ b/staging/src/k8s.io/kube-scheduler/framework/interface.go @@ -332,6 +332,19 @@ type WaitingPod interface { Preempt(pluginName, msg string) bool } +// PodInPreBind represents a pod currently in preBind phase. +type PodInPreBind interface { + // CancelPod cancels the context attached to a goroutine running binding cycle of this pod + // if the pod is not marked as prebound. + // Returns true if the cancel was successfully run. + CancelPod(reason string) bool + + // MarkPrebound marks the pod as prebound, making it impossible to cancel the context of binding cycle + // via PodInPreBind + // Returns false if the context was already canceled. + MarkPrebound() bool +} + // PreFilterResult wraps needed info for scheduler framework to act upon PreFilter phase. type PreFilterResult struct { // The set of nodes that should be considered downstream; if nil then @@ -739,6 +752,15 @@ type Handle interface { // The return value indicates if the pod is waiting or not. RejectWaitingPod(uid types.UID) bool + // AddPodInPreBind adds a pod to the pods in preBind list. + AddPodInPreBind(uid types.UID, cancel context.CancelCauseFunc) + + // GetPodInPreBind returns a pod that is in the binding cycle but before it is bound given its UID. + GetPodInPreBind(uid types.UID) PodInPreBind + + // RemovePodInPreBind removes a pod from the pods in preBind list. + RemovePodInPreBind(uid types.UID) + // ClientSet returns a kubernetes clientSet. ClientSet() clientset.Interface diff --git a/test/integration/scheduler/preemption/preemption_test.go b/test/integration/scheduler/preemption/preemption_test.go index 8a3b621a35a..650c7094bfe 100644 --- a/test/integration/scheduler/preemption/preemption_test.go +++ b/test/integration/scheduler/preemption/preemption_test.go @@ -1805,3 +1805,203 @@ func TestPreemptionRespectsWaitingPod(t *testing.T) { t.Fatalf("Preemptor should be scheduled on big-node, but was scheduled on %s", p.Spec.NodeName) } } + +type perPodBlockingPlugin struct { + shouldBlock bool + blocked chan struct{} + released chan struct{} +} + +// blockingPreBindPlugin is a PreBindPlugin that blocks until a signal is received. 
+type blockingPreBindPlugin struct { + podToChannels map[string]*perPodBlockingPlugin + handle fwk.Handle +} + +const blockingPreBindPluginName = "blocking-prebind-plugin" + +var _ fwk.PreBindPlugin = &blockingPreBindPlugin{} + +func newBlockingPreBindPlugin(_ context.Context, _ runtime.Object, h fwk.Handle) (fwk.Plugin, error) { + return &blockingPreBindPlugin{ + podToChannels: make(map[string]*perPodBlockingPlugin), + handle: h, + }, nil +} + +func (pl *blockingPreBindPlugin) Name() string { + return blockingPreBindPluginName +} + +func (pl *blockingPreBindPlugin) PreBind(ctx context.Context, _ fwk.CycleState, pod *v1.Pod, _ string) *fwk.Status { + podBlocks, ok := pl.podToChannels[pod.Name] + if !ok { + return fwk.NewStatus(fwk.Error, "pod was not prepared in test case") + } + if !podBlocks.shouldBlock { + return nil + } + + close(podBlocks.blocked) + podBlocks.shouldBlock = false + select { + case <-podBlocks.released: + return nil + case <-ctx.Done(): + return fwk.AsStatus(ctx.Err()) + } +} + +func (pl *blockingPreBindPlugin) PreBindPreFlight(ctx context.Context, state fwk.CycleState, p *v1.Pod, nodeName string) (*fwk.PreBindPreFlightResult, *fwk.Status) { + return &fwk.PreBindPreFlightResult{}, nil +} + +func TestPreemptionRespectsBindingPod(t *testing.T) { + // 1. Create a "blocking" prebind plugin that signals when it's running and waits for a specific close. + // 2. Schedule a low-priority pod (victim) that hits this plugin. + // 3. While victim is blocked in PreBind, add a small node and schedule a high-priority pod (preemptor) that fits only on a bigger node. + // 4. Verify that: + // - preemptor takes place on the bigger node + // - victim is NOT deleted, it's rescheduled on to a smaller node + + // Create a node with resources for only one pod. + bigNode := st.MakeNode().Name("big-node").Capacity(map[v1.ResourceName]string{ + v1.ResourceCPU: "2", + v1.ResourceMemory: "2Gi", + }).Obj() + // Victim requires full node resources. + victim := st.MakePod().Name("victim").Priority(lowPriority).Req(map[v1.ResourceName]string{v1.ResourceCPU: "1", v1.ResourceMemory: "1Gi"}).Obj() + // Preemptor also requires full node resources. + preemptor := st.MakePod().Name("preemptor").Priority(highPriority).Req(map[v1.ResourceName]string{v1.ResourceCPU: "1.5", v1.ResourceMemory: "1.5Gi"}).Obj() + + // Register the blocking plugin. 
+ var plugin *blockingPreBindPlugin + registry := make(frameworkruntime.Registry) + err := registry.Register(blockingPreBindPluginName, func(ctx context.Context, obj runtime.Object, fh fwk.Handle) (fwk.Plugin, error) { + pl, err := newBlockingPreBindPlugin(ctx, obj, fh) + if err == nil { + plugin = pl.(*blockingPreBindPlugin) + } + return pl, err + }) + if err != nil { + t.Fatalf("Error registering plugin: %v", err) + } + + cfg := configtesting.V1ToInternalWithDefaults(t, configv1.KubeSchedulerConfiguration{ + Profiles: []configv1.KubeSchedulerProfile{{ + SchedulerName: ptr.To(v1.DefaultSchedulerName), + Plugins: &configv1.Plugins{ + PreBind: configv1.PluginSet{ + Enabled: []configv1.Plugin{ + {Name: blockingPreBindPluginName}, + }, + }, + }, + }}, + }) + + testCtx := testutils.InitTestSchedulerWithOptions(t, + testutils.InitTestAPIServer(t, "preemption-binding", nil), + 0, + scheduler.WithProfiles(cfg.Profiles...), + scheduler.WithFrameworkOutOfTreeRegistry(registry)) + testutils.SyncSchedulerInformerFactory(testCtx) + go testCtx.Scheduler.Run(testCtx.Ctx) + + victimBlockingPlugin := &perPodBlockingPlugin{ + shouldBlock: true, + blocked: make(chan struct{}), + released: make(chan struct{}), + } + plugin.podToChannels[victim.Name] = victimBlockingPlugin + plugin.podToChannels[preemptor.Name] = &perPodBlockingPlugin{ + shouldBlock: false, + blocked: make(chan struct{}), + released: make(chan struct{}), + } + + cs := testCtx.ClientSet + + if _, err := createNode(cs, bigNode); err != nil { + t.Fatalf("Error creating node: %v", err) + } + + // 1. Run victim. + t.Logf("Creating victim pod") + victim, err = cs.CoreV1().Pods(testCtx.NS.Name).Create(testCtx.Ctx, victim, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Error creating victim: %v", err) + } + + // Wait for victim to reach PreBind. + t.Logf("Waiting for victim to reach PreBind") + select { + case <-victimBlockingPlugin.blocked: + t.Logf("Victim reached PreBind") + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("Timed out waiting for victim to reach PreBind") + } + + // 2. Add a small node that will fit the victim once it's preempted. + smallNode := st.MakeNode().Name("small-node").Capacity(map[v1.ResourceName]string{ + v1.ResourceCPU: "1", + v1.ResourceMemory: "1Gi", + }).Obj() + if _, err := createNode(cs, smallNode); err != nil { + t.Fatalf("Error creating node: %v", err) + } + + // 3. Run preemptor pod. + t.Logf("Creating preemptor pod") + preemptor, err = cs.CoreV1().Pods(testCtx.NS.Name).Create(testCtx.Ctx, preemptor, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Error creating preemptor: %v", err) + } + + // 4. Wait for the preemptor to be scheduled (or at least nominated) and check the victim. + // The preemptor should eventually be scheduled or cause victim preemption. + // Since the victim is in PreBind (binding cycle), the preemptor's preemption logic (PostFilter) should find it. + // It should call CancelPod() on the victim's PodInPreBind entry, causing it to go to the backoff queue. + // The victim pod should NOT be deleted from the API server. + // Instead it should be rescheduled onto a smaller node. 
+ err = wait.PollUntilContextTimeout(testCtx.Ctx, 100*time.Millisecond, 10*time.Second, false, func(ctx context.Context) (bool, error) { + // Fail if the victim pod was deleted + v, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(ctx, victim.Name, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + return false, fmt.Errorf("victim pod was deleted") + } + return false, err + } + // Check if victim was rescheduled + _, cond := podutil.GetPodCondition(&v.Status, v1.PodScheduled) + if cond != nil && cond.Status == v1.ConditionTrue { + return true, nil + } + return false, nil + }) + if err != nil { + t.Fatalf("Failed waiting for victim validation: %v", err) + } + + // 5. Check that the preemptor and victim are scheduled on the expected nodes: the victim on the small node and the preemptor on the big node. + v, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(testCtx.Ctx, victim.Name, metav1.GetOptions{}) + if err != nil { + t.Fatalf("Error getting victim: %v", err) + } + if v.Spec.NodeName != "small-node" { + t.Fatalf("Victim should be scheduled on small-node, but was scheduled on %s", v.Spec.NodeName) + } + + p, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(testCtx.Ctx, preemptor.Name, metav1.GetOptions{}) + if err != nil { + t.Fatalf("Error getting preemptor: %v", err) + } + if p.Spec.NodeName != "big-node" { + t.Fatalf("Preemptor should be scheduled on big-node, but was scheduled on %s", p.Spec.NodeName) + } + + // Release the blocked plugin just in case, ensuring clean teardown. + close(victimBlockingPlugin.released) +}
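The cancellation path above relies on Go's cancel-cause contexts: bindingCycle wraps the prebind context with context.WithCancelCause, PodInPreBind.CancelPod cancels it with a descriptive error, and RunPreBindPlugins surfaces that error via context.Cause. A minimal standalone sketch of this pattern (illustrative only, not part of the patch; simulatePreBind and the "test-plugin" cause are hypothetical names):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// simulatePreBind stands in for a long-running PreBind plugin that respects
// context cancellation, similar to how RunPreBindPlugins behaves in the patch.
func simulatePreBind(ctx context.Context) error {
	select {
	case <-time.After(5 * time.Second):
		return nil // prebind finished normally
	case <-ctx.Done():
		// Return the cancellation cause instead of the generic context.Canceled,
		// mirroring the cause unwrapping added to RunPreBindPlugins.
		return context.Cause(ctx)
	}
}

func main() {
	ctx, cancel := context.WithCancelCause(context.Background())
	defer cancel(nil)

	// A "preemptor" cancels the victim's prebind with a descriptive cause,
	// as Executor.PreemptPod does via PodInPreBind.CancelPod.
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel(errors.New("preempted by test-plugin"))
	}()

	if err := simulatePreBind(ctx); err != nil {
		fmt.Println("prebind aborted:", err) // prints the preemption cause
	}
}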