diff --git a/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go b/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go index 41592e22717..27de7964105 100644 --- a/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go +++ b/pkg/scheduler/framework/plugins/defaultpreemption/default_preemption_test.go @@ -477,6 +477,7 @@ func TestPostFilter(t *testing.T) { frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(tt.pods, tt.nodes)), frameworkruntime.WithLogger(logger), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), ) if err != nil { t.Fatal(err) @@ -2247,6 +2248,7 @@ func TestPreempt(t *testing.T) { frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(testPods, nodes)), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithLogger(logger), frameworkruntime.WithPodActivator(&fakePodActivator{}), ) diff --git a/pkg/scheduler/framework/preemption/executor.go b/pkg/scheduler/framework/preemption/executor.go index 0553c11b5df..739e9d7e299 100644 --- a/pkg/scheduler/framework/preemption/executor.go +++ b/pkg/scheduler/framework/preemption/executor.go @@ -77,6 +77,7 @@ func newExecutor(fh fwk.Handle) *Executor { logger := klog.FromContext(ctx) skipAPICall := false + eventMessage := fmt.Sprintf("Preempted by pod %v on node %v", preemptor.UID, c.Name()) // If the victim is a WaitingPod, try to preempt it without a delete call (victim will go back to backoff queue). // Otherwise we should delete the victim. if waitingPod := e.fh.GetWaitingPod(victim.UID); waitingPod != nil { @@ -84,6 +85,14 @@ func newExecutor(fh fwk.Handle) *Executor { logger.V(2).Info("Preemptor pod preempted a waiting pod", "preemptor", klog.KObj(preemptor), "waitingPod", klog.KObj(victim), "node", c.Name()) skipAPICall = true } + } else if podInPreBind := e.fh.GetPodInPreBind(victim.UID); podInPreBind != nil { + // If the victim is in the preBind cancel the binding process. + if podInPreBind.CancelPod(fmt.Sprintf("preempted by %s", pluginName)) { + logger.V(2).Info("Preemptor pod rejected a pod in preBind", "preemptor", klog.KObj(preemptor), "podInPreBind", klog.KObj(victim), "node", c.Name()) + skipAPICall = true + } else { + logger.V(5).Info("Failed to reject a pod in preBind, falling back to deletion via api call", "preemptor", klog.KObj(preemptor), "podInPreBind", klog.KObj(victim), "node", c.Name()) + } } if !skipAPICall { condition := &v1.PodCondition{ @@ -114,9 +123,11 @@ func newExecutor(fh fwk.Handle) *Executor { return nil } logger.V(2).Info("Preemptor Pod preempted victim Pod", "preemptor", klog.KObj(preemptor), "victim", klog.KObj(victim), "node", c.Name()) + } else { + eventMessage += " (in kube-scheduler memory)." 
} - fh.EventRecorder().Eventf(victim, preemptor, v1.EventTypeNormal, "Preempted", "Preempting", "Preempted by pod %v on node %v", preemptor.UID, c.Name()) + fh.EventRecorder().Eventf(victim, preemptor, v1.EventTypeNormal, "Preempted", "Preempting", eventMessage) return nil } diff --git a/pkg/scheduler/framework/preemption/executor_test.go b/pkg/scheduler/framework/preemption/executor_test.go index 908a92cff8e..0efef7c9274 100644 --- a/pkg/scheduler/framework/preemption/executor_test.go +++ b/pkg/scheduler/framework/preemption/executor_test.go @@ -47,6 +47,7 @@ import ( apidispatcher "k8s.io/kubernetes/pkg/scheduler/backend/api_dispatcher" internalcache "k8s.io/kubernetes/pkg/scheduler/backend/cache" internalqueue "k8s.io/kubernetes/pkg/scheduler/backend/queue" + "k8s.io/kubernetes/pkg/scheduler/framework" apicalls "k8s.io/kubernetes/pkg/scheduler/framework/api_calls" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/defaultbinder" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/queuesort" @@ -589,6 +590,7 @@ func TestPrepareCandidate(t *testing.T) { frameworkruntime.WithLogger(logger), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(tt.testPods, nodes)), frameworkruntime.WithPodNominator(nominator), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), @@ -817,6 +819,7 @@ func TestPrepareCandidateAsyncSetsPreemptingSets(t *testing.T) { frameworkruntime.WithLogger(logger), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(testPods, nodes)), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), frameworkruntime.WithPodNominator(internalqueue.NewSchedulingQueue(nil, informerFactory)), @@ -1054,6 +1057,7 @@ func TestAsyncPreemptionFailure(t *testing.T) { frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), frameworkruntime.WithPodActivator(fakeActivator), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(snapshotPods, []*v1.Node{st.MakeNode().Name(node1Name).Obj()})), ) if err != nil { @@ -1191,3 +1195,127 @@ func TestRemoveNominatedNodeName(t *testing.T) { } } } + +func TestPreemptPod(t *testing.T) { + preemptorPod := st.MakePod().Name("p").UID("p").Priority(highPriority).Obj() + victimPod := st.MakePod().Name("v").UID("v").Priority(midPriority).Obj() + + tests := []struct { + name string + addVictimToPrebind bool + addVictimToWaiting bool + expectCancel bool + expectedActions []string + }{ + { + name: "victim is in preBind, context should be cancelled", + addVictimToPrebind: true, + addVictimToWaiting: false, + expectCancel: true, + expectedActions: []string{}, + }, + { + name: "victim is in waiting pods, it should be rejected (no calls to apiserver)", + addVictimToPrebind: false, + addVictimToWaiting: true, + expectCancel: false, + expectedActions: []string{}, + }, + { + name: "victim is not in waiting/preBind pods, pod should be deleted", + 
addVictimToPrebind: false, + addVictimToWaiting: false, + expectCancel: false, + expectedActions: []string{"patch", "delete"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podsInPreBind := frameworkruntime.NewPodsInPreBindMap() + waitingPods := frameworkruntime.NewWaitingPodsMap() + registeredPlugins := append([]tf.RegisterPluginFunc{ + tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New)}, + tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New), + tf.RegisterPermitPlugin(waitingPermitPluginName, newWaitingPermitPlugin), + ) + objs := []runtime.Object{preemptorPod, victimPod} + cs := clientsetfake.NewClientset(objs...) + informerFactory := informers.NewSharedInformerFactory(cs, 0) + eventBroadcaster := events.NewBroadcaster(&events.EventSinkImpl{Interface: cs.EventsV1()}) + logger, ctx := ktesting.NewTestContext(t) + + fwk, err := tf.NewFramework( + ctx, + registeredPlugins, "", + frameworkruntime.WithClientSet(cs), + frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot([]*v1.Pod{}, []*v1.Node{})), + frameworkruntime.WithInformerFactory(informerFactory), + frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podsInPreBind), + frameworkruntime.WithLogger(logger), + frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, "test-scheduler")), + ) + if err != nil { + t.Fatal(err) + } + var victimCtx context.Context + var cancel context.CancelCauseFunc + if tt.addVictimToPrebind { + victimCtx, cancel = context.WithCancelCause(context.Background()) + fwk.AddPodInPreBind(victimPod.UID, cancel) + } + if tt.addVictimToWaiting { + status := fwk.RunPermitPlugins(ctx, framework.NewCycleState(), victimPod, "fake-node") + if !status.IsWait() { + t.Fatalf("Failed to add a pod to waiting list") + } + } + pe := NewEvaluator("FakePreemptionScorePostFilter", fwk, &FakePreemptionScorePostFilterPlugin{}, false) + + err = pe.PreemptPod(ctx, &candidate{}, preemptorPod, victimPod, "test-plugin") + if err != nil { + t.Fatal(err) + } + if tt.expectCancel { + if victimCtx.Err() == nil { + t.Errorf("Context of a binding pod should be cancelled") + } + } else { + if victimCtx != nil && victimCtx.Err() != nil { + t.Errorf("Context of a normal pod should not be cancelled") + } + } + + // check if the API call was made + actions := cs.Actions() + if len(actions) != len(tt.expectedActions) { + t.Errorf("Expected %d actions, but got %d", len(tt.expectedActions), len(actions)) + } + for i, action := range actions { + if action.GetVerb() != tt.expectedActions[i] { + t.Errorf("Expected action %s, but got %s", tt.expectedActions[i], action.GetVerb()) + } + } + }) + } +} + +// waitingPermitPlugin is a PermitPlugin that always returns Wait. 
+type waitingPermitPlugin struct{} + +var _ fwk.PermitPlugin = &waitingPermitPlugin{} + +const waitingPermitPluginName = "waitingPermitPlugin" + +func newWaitingPermitPlugin(_ context.Context, _ runtime.Object, _ fwk.Handle) (fwk.Plugin, error) { + return &waitingPermitPlugin{}, nil +} + +func (pl *waitingPermitPlugin) Name() string { + return waitingPermitPluginName +} + +func (pl *waitingPermitPlugin) Permit(ctx context.Context, _ fwk.CycleState, _ *v1.Pod, nodeName string) (*fwk.Status, time.Duration) { + return fwk.NewStatus(fwk.Wait, ""), 10 * time.Second +} diff --git a/pkg/scheduler/framework/runtime/framework.go b/pkg/scheduler/framework/runtime/framework.go index 88f659389dd..fdf9ea7e8bf 100644 --- a/pkg/scheduler/framework/runtime/framework.go +++ b/pkg/scheduler/framework/runtime/framework.go @@ -58,6 +58,7 @@ type frameworkImpl struct { registry Registry snapshotSharedLister fwk.SharedLister waitingPods *waitingPodsMap + podsInPreBind *podsInPreBindMap scorePluginWeight map[string]int preEnqueuePlugins []fwk.PreEnqueuePlugin enqueueExtensions []fwk.EnqueueExtensions @@ -153,6 +154,7 @@ type frameworkOptions struct { captureProfile CaptureProfile parallelizer parallelize.Parallelizer waitingPods *waitingPodsMap + podsInPreBind *podsInPreBindMap apiDispatcher *apidispatcher.APIDispatcher workloadManager fwk.WorkloadManager logger *klog.Logger @@ -285,6 +287,13 @@ func WithWaitingPods(wp *waitingPodsMap) Option { } } +// WithPodsInPreBind sets podsInPreBind for the scheduling frameworkImpl. +func WithPodsInPreBind(bp *podsInPreBindMap) Option { + return func(o *frameworkOptions) { + o.podsInPreBind = bp + } +} + // WithLogger overrides the default logger from k8s.io/klog. func WithLogger(logger klog.Logger) Option { return func(o *frameworkOptions) { @@ -323,6 +332,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler sharedCSIManager: options.sharedCSIManager, scorePluginWeight: make(map[string]int), waitingPods: options.waitingPods, + podsInPreBind: options.podsInPreBind, clientSet: options.clientSet, kubeConfig: options.kubeConfig, eventRecorder: options.eventRecorder, @@ -1456,6 +1466,10 @@ func (f *frameworkImpl) RunPreBindPlugins(ctx context.Context, state fwk.CycleSt return plStatus } err := plStatus.AsError() + if errors.Is(err, context.Canceled) { + err = context.Cause(ctx) + } + logger.Error(err, "Plugin failed", "plugin", pl.Name(), "pod", klog.KObj(pod), "node", nodeName) return fwk.AsStatus(fmt.Errorf("running PreBind plugin %q: %w", pl.Name(), err)) } @@ -1894,6 +1908,24 @@ func (f *frameworkImpl) RejectWaitingPod(uid types.UID) bool { return false } +// AddPodInPreBind adds a pod to the pods in preBind list. +func (f *frameworkImpl) AddPodInPreBind(uid types.UID, cancel context.CancelCauseFunc) { + f.podsInPreBind.add(uid, cancel) +} + +// GetPodInPreBind returns a pod that is in the binding cycle but before it is bound given its UID. +func (f *frameworkImpl) GetPodInPreBind(uid types.UID) fwk.PodInPreBind { + if bp := f.podsInPreBind.get(uid); bp != nil { + return bp + } + return nil +} + +// RemovePodInPreBind removes a pod from the pods in preBind list. +func (f *frameworkImpl) RemovePodInPreBind(uid types.UID) { + f.podsInPreBind.remove(uid) +} + // HasFilterPlugins returns true if at least one filter plugin is defined. 
func (f *frameworkImpl) HasFilterPlugins() bool { return len(f.filterPlugins) > 0 diff --git a/pkg/scheduler/framework/runtime/pods_in_prebind_map.go b/pkg/scheduler/framework/runtime/pods_in_prebind_map.go new file mode 100644 index 00000000000..41f369e51da --- /dev/null +++ b/pkg/scheduler/framework/runtime/pods_in_prebind_map.go @@ -0,0 +1,98 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtime + +import ( + "context" + "errors" + "sync" + + "k8s.io/apimachinery/pkg/types" + fwk "k8s.io/kube-scheduler/framework" +) + +// podsInPreBindMap is a thread-safe map of pods currently in the preBind phase. +type podsInPreBindMap struct { + pods map[types.UID]*podInPreBind + mu sync.RWMutex +} + +// NewPodsInPreBindMap creates an empty podsInPreBindMap. +func NewPodsInPreBindMap() *podsInPreBindMap { + return &podsInPreBindMap{ + pods: make(map[types.UID]*podInPreBind), + } +} + +// get returns a pod from the map if it exists. +func (pbm *podsInPreBindMap) get(uid types.UID) *podInPreBind { + pbm.mu.RLock() + defer pbm.mu.RUnlock() + return pbm.pods[uid] +} + +// add adds a pod to map, overwriting existing one. +func (pbm *podsInPreBindMap) add(uid types.UID, cancel context.CancelCauseFunc) { + pbm.mu.Lock() + defer pbm.mu.Unlock() + pbm.pods[uid] = &podInPreBind{cancel: cancel} +} + +// remove removes a pod from the map. +func (pbm *podsInPreBindMap) remove(uid types.UID) { + pbm.mu.Lock() + defer pbm.mu.Unlock() + delete(pbm.pods, uid) +} + +var _ fwk.PodInPreBind = &podInPreBind{} + +// podInPreBind describes a pod in the preBind phase, before the bind was called for a pod. +type podInPreBind struct { + finished bool + canceled bool + cancel context.CancelCauseFunc + mu sync.Mutex +} + +// CancelPod cancels the context running the preBind phase +// for a given pod. +func (bp *podInPreBind) CancelPod(message string) bool { + bp.mu.Lock() + defer bp.mu.Unlock() + if bp.finished { + return false + } + if !bp.canceled { + bp.cancel(errors.New(message)) + } + bp.canceled = true + return true +} + +// MarkPrebound marks the pod as finished with preBind phase +// of binding cycle, making it impossible to cancel +// the binding cycle for it. +func (bp *podInPreBind) MarkPrebound() bool { + bp.mu.Lock() + defer bp.mu.Unlock() + if bp.canceled { + return false + } + bp.finished = true + return true +} diff --git a/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go b/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go new file mode 100644 index 00000000000..add5800f983 --- /dev/null +++ b/pkg/scheduler/framework/runtime/pods_in_prebind_map_test.go @@ -0,0 +1,121 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtime + +import ( + "context" + "testing" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2/ktesting" +) + +const ( + testPodUID = types.UID("pod-1") +) + +func TestPodsInBindingMap(t *testing.T) { + tests := []struct { + name string + scenario func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) + }{ + { + name: "add and get", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + if pod == nil { + t.Fatalf("expected pod to be in map") + } + }, + }, + { + name: "add, remove and get", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + m.remove(testPodUID) + if m.get(testPodUID) != nil { + t.Errorf("expected pod to be removed from map") + } + }, + }, + { + name: "Verify CancelPod logic", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + // First cancel should succeed + if !pod.CancelPod("test cancel") { + t.Errorf("First CancelPod should return true") + } + if ctx.Err() == nil { + t.Errorf("Context should be cancelled") + } + + // Second cancel should also return true + if !pod.CancelPod("test cancel") { + t.Errorf("Second CancelPod should return true") + } + }, + }, + { + name: "Verify MarkBound logic", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + if !pod.MarkPrebound() { + t.Errorf("MarkBound should return true for fresh pod") + } + if !pod.finished { + t.Errorf("finished should be true") + } + + // Try to cancel after binding + if pod.CancelPod("test cancel") { + t.Errorf("CancelPod should return false after MarkBound") + } + if ctx.Err() != nil { + t.Errorf("Context should NOT be cancelled if MarkBound succeeded first") + } + }, + }, + { + name: "Verify MarkBound fails on cancelled pod", + scenario: func(t *testing.T, ctx context.Context, cancel context.CancelCauseFunc, m *podsInPreBindMap) { + m.add(testPodUID, cancel) + pod := m.get(testPodUID) + + pod.CancelPod("test cancel") + if pod.MarkPrebound() { + t.Errorf("MarkBound should return false on cancelled pod") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := NewPodsInPreBindMap() + _, ctx := ktesting.NewTestContext(t) + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + tt.scenario(t, ctx, cancel, m) + }) + } +} diff --git a/pkg/scheduler/schedule_one.go b/pkg/scheduler/schedule_one.go index db5bba104fd..064e37b58b4 100644 --- a/pkg/scheduler/schedule_one.go +++ b/pkg/scheduler/schedule_one.go @@ -399,8 +399,9 @@ func (sched *Scheduler) bindingCycle( assumedPod := assumedPodInfo.Pod + var preFlightStatus *fwk.Status if sched.nominatedNodeNameForExpectationEnabled { - preFlightStatus := schedFramework.RunPreBindPreFlights(ctx, state, assumedPod, scheduleResult.SuggestedHost) + preFlightStatus = 
schedFramework.RunPreBindPreFlights(ctx, state, assumedPod, scheduleResult.SuggestedHost) if preFlightStatus.Code() == fwk.Error || // Unschedulable status is not supported in PreBindPreFlight and hence we regard it as an error. preFlightStatus.IsRejected() { @@ -444,11 +445,26 @@ func (sched *Scheduler) bindingCycle( // we can free the cluster events stored in the scheduling queue sooner, which is worth for busy clusters memory consumption wise. sched.SchedulingQueue.Done(assumedPod.UID) + // If we are going to run prebind plugins we put the pod in binding map to optimize preemption. + if preFlightStatus.IsSuccess() { + var podInPreBindCancel context.CancelCauseFunc + ctx, podInPreBindCancel = context.WithCancelCause(ctx) + defer podInPreBindCancel(nil) + defer schedFramework.RemovePodInPreBind(assumedPod.UID) + schedFramework.AddPodInPreBind(assumedPod.UID, podInPreBindCancel) + } // Run "prebind" plugins. if status := schedFramework.RunPreBindPlugins(ctx, state, assumedPod, scheduleResult.SuggestedHost); !status.IsSuccess() { return status } + // Verify that pod was not preempted during prebinding. + bindingPod := schedFramework.GetPodInPreBind(assumedPod.UID) + if bindingPod != nil && !bindingPod.MarkPrebound() { + err := context.Cause(ctx) + return fwk.AsStatus(err) + } + // Run "bind" plugins. if status := sched.bind(ctx, schedFramework, assumedPod, scheduleResult.SuggestedHost, state); !status.IsSuccess() { return status diff --git a/pkg/scheduler/schedule_one_podgroup_test.go b/pkg/scheduler/schedule_one_podgroup_test.go index c7853355ffa..02976c06634 100644 --- a/pkg/scheduler/schedule_one_podgroup_test.go +++ b/pkg/scheduler/schedule_one_podgroup_test.go @@ -1038,10 +1038,12 @@ func TestSubmitPodGroupAlgorithmResult(t *testing.T) { }, } waitingPods := frameworkruntime.NewWaitingPodsMap() + podInPreBind := frameworkruntime.NewPodsInPreBindMap() schedFwk, err := frameworkruntime.NewFramework(ctx, registry, &profileCfg, frameworkruntime.WithClientSet(client), frameworkruntime.WithEventRecorder(events.NewFakeRecorder(100)), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podInPreBind), ) if err != nil { t.Fatalf("Failed to create new framework: %v", err) diff --git a/pkg/scheduler/schedule_one_test.go b/pkg/scheduler/schedule_one_test.go index 88e79eef2a4..6dd19512f51 100644 --- a/pkg/scheduler/schedule_one_test.go +++ b/pkg/scheduler/schedule_one_test.go @@ -368,6 +368,30 @@ func newFakeNodeSelectorDependOnPodAnnotation(_ context.Context, _ runtime.Objec return &fakeNodeSelectorDependOnPodAnnotation{}, nil } +var blockingPreBindStartCh = make(chan struct{}, 1) + +var _ fwk.PreBindPlugin = &BlockingPreBindPlugin{} + +type BlockingPreBindPlugin struct { + handle fwk.Handle +} + +func (p *BlockingPreBindPlugin) Name() string { return "BlockingPreBindPlugin" } + +func (p *BlockingPreBindPlugin) PreBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { + go func() { blockingPreBindStartCh <- struct{}{} }() + <-ctx.Done() + return fwk.AsStatus(ctx.Err()) +} + +func (p *BlockingPreBindPlugin) PreBindPreFlight(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) (*fwk.PreBindPreFlightResult, *fwk.Status) { + return &fwk.PreBindPreFlightResult{}, nil +} + +func NewBlockingPreBindPlugin(_ context.Context, _ runtime.Object, h fwk.Handle) (fwk.Plugin, error) { + return &BlockingPreBindPlugin{handle: h}, nil +} + type TestPlugin struct { name string } @@ -718,6 +742,10 @@ func TestSchedulerScheduleOne(t 
*testing.T) { asyncAPICallsEnabled *bool // If nil, the test case is run with both true and false scheduleAsPodGroup *bool + // postSchedulingCycle is run after ScheduleOne function returns + // (synchronous part of scheduling is done and binding goroutine is launched) + // and before blocking execution on waiting for event with eventReason + postSchedulingCycle func(context.Context, *Scheduler) } table := []testItem{ { @@ -898,6 +926,29 @@ func TestSchedulerScheduleOne(t *testing.T) { expectError: fmt.Errorf(`running PreBind plugin "FakePreBind": %w`, preBindErr), eventReason: "FailedScheduling", }, + { + name: "prebind pod cancelled during prebind (external preemption)", + sendPod: testPod, + registerPluginFuncs: []tf.RegisterPluginFunc{ + tf.RegisterPreBindPlugin("BlockingPreBindPlugin", NewBlockingPreBindPlugin), + }, + mockScheduleResult: scheduleResultOk, + postSchedulingCycle: func(ctx context.Context, sched *Scheduler) { + <-blockingPreBindStartCh // Wait for plugin to start + // Trigger preemption from "outside" + if bp := sched.Profiles[testSchedulerName].GetPodInPreBind(testPod.UID); bp != nil { + bp.CancelPod("context cancelled externally") + } + }, + nominatedNodeNameForExpectationEnabled: ptr.To(true), + expectAssumedPod: assignedTestPod, + expectErrorPod: assignedTestPod, + expectForgetPod: assignedTestPod, + expectNominatedNodeName: testNode.Name, + expectPodInBackoffQ: testPod, + expectError: fmt.Errorf(`running PreBind plugin "BlockingPreBindPlugin": context cancelled externally`), + eventReason: "FailedScheduling", + }, { name: "binding failed", sendPod: testPod, @@ -1064,6 +1115,7 @@ func TestSchedulerScheduleOne(t *testing.T) { frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, testSchedulerName)), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithWorkloadManager(wm), ) @@ -1115,6 +1167,10 @@ func TestSchedulerScheduleOne(t *testing.T) { sched.nodeInfoSnapshot = internalcache.NewEmptySnapshot() sched.ScheduleOne(ctx) + if item.postSchedulingCycle != nil { + item.postSchedulingCycle(ctx, sched) + } + if item.podToAdmit != nil { for { if waitingPod := sched.Profiles[testSchedulerName].GetWaitingPod(item.podToAdmit.pod); waitingPod != nil { @@ -1684,6 +1740,7 @@ func TestScheduleOneMarksPodAsProcessedBeforePreBind(t *testing.T) { frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithEventRecorder(eventBroadcaster.NewRecorder(scheme.Scheme, testSchedulerName)), frameworkruntime.WithWaitingPods(frameworkruntime.NewWaitingPodsMap()), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), ) if err != nil { t.Fatal(err) @@ -4562,6 +4619,7 @@ func setupTestScheduler(ctx context.Context, t *testing.T, client clientset.Inte frameworkruntime.WithInformerFactory(informerFactory), frameworkruntime.WithPodNominator(schedulingQueue), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(frameworkruntime.NewPodsInPreBindMap()), frameworkruntime.WithSnapshotSharedLister(snapshot), ) if apiDispatcher != nil { diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 16c7da3802a..27c546c3e9b 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -318,6 +318,9 @@ func New(ctx context.Context, // waitingPods holds all the pods that 
are in the scheduler and waiting in the permit stage waitingPods := frameworkruntime.NewWaitingPodsMap() + // podsInPreBind holds all the pods that are in the scheduler in the preBind phase + podsInPreBind := frameworkruntime.NewPodsInPreBindMap() + var resourceClaimCache *assumecache.AssumeCache var resourceSliceTracker *resourceslicetracker.Tracker var draManager fwk.SharedDRAManager @@ -365,6 +368,7 @@ func New(ctx context.Context, frameworkruntime.WithExtenders(extenders), frameworkruntime.WithMetricsRecorder(metricsRecorder), frameworkruntime.WithWaitingPods(waitingPods), + frameworkruntime.WithPodsInPreBind(podsInPreBind), frameworkruntime.WithAPIDispatcher(apiDispatcher), frameworkruntime.WithSharedCSIManager(sharedCSIManager), frameworkruntime.WithWorkloadManager(workloadManager), diff --git a/staging/src/k8s.io/kube-scheduler/framework/interface.go b/staging/src/k8s.io/kube-scheduler/framework/interface.go index 4d76e4d7ab2..a3dd45d5af9 100644 --- a/staging/src/k8s.io/kube-scheduler/framework/interface.go +++ b/staging/src/k8s.io/kube-scheduler/framework/interface.go @@ -332,6 +332,19 @@ type WaitingPod interface { Preempt(pluginName, msg string) bool } +// PodInPreBind represents a pod currently in preBind phase. +type PodInPreBind interface { + // CancelPod cancels the context attached to a goroutine running binding cycle of this pod + // if the pod is not marked as prebound. + // Returns true if the cancel was successfully run. + CancelPod(reason string) bool + + // MarkPrebound marks the pod as prebound, making it impossible to cancel the context of binding cycle + // via PodInPreBind + // Returns false if the context was already canceled. + MarkPrebound() bool +} + // PreFilterResult wraps needed info for scheduler framework to act upon PreFilter phase. type PreFilterResult struct { // The set of nodes that should be considered downstream; if nil then @@ -739,6 +752,15 @@ type Handle interface { // The return value indicates if the pod is waiting or not. RejectWaitingPod(uid types.UID) bool + // AddPodInPreBind adds a pod to the pods in preBind list. + AddPodInPreBind(uid types.UID, cancel context.CancelCauseFunc) + + // GetPodInPreBind returns a pod that is in the binding cycle but before it is bound given its UID. + GetPodInPreBind(uid types.UID) PodInPreBind + + // RemovePodInPreBind removes a pod from the pods in preBind list. + RemovePodInPreBind(uid types.UID) + // ClientSet returns a kubernetes clientSet. ClientSet() clientset.Interface diff --git a/test/integration/scheduler/preemption/preemption_test.go b/test/integration/scheduler/preemption/preemption_test.go index 8a3b621a35a..650c7094bfe 100644 --- a/test/integration/scheduler/preemption/preemption_test.go +++ b/test/integration/scheduler/preemption/preemption_test.go @@ -1805,3 +1805,203 @@ func TestPreemptionRespectsWaitingPod(t *testing.T) { t.Fatalf("Preemptor should be scheduled on big-node, but was scheduled on %s", p.Spec.NodeName) } } + +type perPodBlockingPlugin struct { + shouldBlock bool + blocked chan struct{} + released chan struct{} +} + +// blockingPreBindPlugin is a PreBindPlugin that blocks until a signal is received. 
+type blockingPreBindPlugin struct { + podToChannels map[string]*perPodBlockingPlugin + handle fwk.Handle +} + +const blockingPreBindPluginName = "blocking-prebind-plugin" + +var _ fwk.PreBindPlugin = &blockingPreBindPlugin{} + +func newBlockingPreBindPlugin(_ context.Context, _ runtime.Object, h fwk.Handle) (fwk.Plugin, error) { + return &blockingPreBindPlugin{ + podToChannels: make(map[string]*perPodBlockingPlugin), + handle: h, + }, nil +} + +func (pl *blockingPreBindPlugin) Name() string { + return blockingPreBindPluginName +} + +func (pl *blockingPreBindPlugin) PreBind(ctx context.Context, _ fwk.CycleState, pod *v1.Pod, _ string) *fwk.Status { + podBlocks, ok := pl.podToChannels[pod.Name] + if !ok { + return fwk.NewStatus(fwk.Error, "pod was not prepared in test case") + } + if !podBlocks.shouldBlock { + return nil + } + + close(podBlocks.blocked) + podBlocks.shouldBlock = false + select { + case <-podBlocks.released: + return nil + case <-ctx.Done(): + return fwk.AsStatus(ctx.Err()) + } +} + +func (pl *blockingPreBindPlugin) PreBindPreFlight(ctx context.Context, state fwk.CycleState, p *v1.Pod, nodeName string) (*fwk.PreBindPreFlightResult, *fwk.Status) { + return &fwk.PreBindPreFlightResult{}, nil +} + +func TestPreemptionRespectsBindingPod(t *testing.T) { + // 1. Create a "blocking" prebind plugin that signals when it's running and waits until its release channel is closed. + // 2. Schedule a low-priority pod (victim) that hits this plugin. + // 3. While the victim is blocked in PreBind, add a small node and schedule a high-priority pod (preemptor) that fits only on the big node. + // 4. Verify that: + // - the preemptor is scheduled on the big node + // - the victim is NOT deleted; it is rescheduled onto the small node + + // Create a big node that can hold only one of the two pods. + bigNode := st.MakeNode().Name("big-node").Capacity(map[v1.ResourceName]string{ + v1.ResourceCPU: "2", + v1.ResourceMemory: "2Gi", + }).Obj() + // Victim requests half of the big node's resources. + victim := st.MakePod().Name("victim").Priority(lowPriority).Req(map[v1.ResourceName]string{v1.ResourceCPU: "1", v1.ResourceMemory: "1Gi"}).Obj() + // Preemptor requests more than half of the big node, so it cannot fit alongside the victim. + preemptor := st.MakePod().Name("preemptor").Priority(highPriority).Req(map[v1.ResourceName]string{v1.ResourceCPU: "1.5", v1.ResourceMemory: "1.5Gi"}).Obj() + + // Register the blocking plugin.
+ var plugin *blockingPreBindPlugin + registry := make(frameworkruntime.Registry) + err := registry.Register(blockingPreBindPluginName, func(ctx context.Context, obj runtime.Object, fh fwk.Handle) (fwk.Plugin, error) { + pl, err := newBlockingPreBindPlugin(ctx, obj, fh) + if err == nil { + plugin = pl.(*blockingPreBindPlugin) + } + return pl, err + }) + if err != nil { + t.Fatalf("Error registering plugin: %v", err) + } + + cfg := configtesting.V1ToInternalWithDefaults(t, configv1.KubeSchedulerConfiguration{ + Profiles: []configv1.KubeSchedulerProfile{{ + SchedulerName: ptr.To(v1.DefaultSchedulerName), + Plugins: &configv1.Plugins{ + PreBind: configv1.PluginSet{ + Enabled: []configv1.Plugin{ + {Name: blockingPreBindPluginName}, + }, + }, + }, + }}, + }) + + testCtx := testutils.InitTestSchedulerWithOptions(t, + testutils.InitTestAPIServer(t, "preemption-binding", nil), + 0, + scheduler.WithProfiles(cfg.Profiles...), + scheduler.WithFrameworkOutOfTreeRegistry(registry)) + testutils.SyncSchedulerInformerFactory(testCtx) + go testCtx.Scheduler.Run(testCtx.Ctx) + + victimBlockingPlugin := &perPodBlockingPlugin{ + shouldBlock: true, + blocked: make(chan struct{}), + released: make(chan struct{}), + } + plugin.podToChannels[victim.Name] = victimBlockingPlugin + plugin.podToChannels[preemptor.Name] = &perPodBlockingPlugin{ + shouldBlock: false, + blocked: make(chan struct{}), + released: make(chan struct{}), + } + + cs := testCtx.ClientSet + + if _, err := createNode(cs, bigNode); err != nil { + t.Fatalf("Error creating node: %v", err) + } + + // 1. Run victim. + t.Logf("Creating victim pod") + victim, err = cs.CoreV1().Pods(testCtx.NS.Name).Create(testCtx.Ctx, victim, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Error creating victim: %v", err) + } + + // Wait for victim to reach PreBind. + t.Logf("Waiting for victim to reach PreBind") + select { + case <-victimBlockingPlugin.blocked: + t.Logf("Victim reached PreBind") + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("Timed out waiting for victim to reach PreBind") + } + + // 2. Add a small node that will fit the victim once it is preempted. + smallNode := st.MakeNode().Name("small-node").Capacity(map[v1.ResourceName]string{ + v1.ResourceCPU: "1", + v1.ResourceMemory: "1Gi", + }).Obj() + if _, err := createNode(cs, smallNode); err != nil { + t.Fatalf("Error creating node: %v", err) + } + + // 3. Run preemptor pod. + t.Logf("Creating preemptor pod") + preemptor, err = cs.CoreV1().Pods(testCtx.NS.Name).Create(testCtx.Ctx, preemptor, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Error creating preemptor: %v", err) + } + + // 4. Wait for the preemptor to be scheduled (or at least nominated) and check the victim. + // The preemptor should eventually be scheduled or cause the victim's preemption. + // Since the victim is in PreBind (binding cycle), the preemptor's preemption logic (PostFilter) should find it. + // It should call CancelPod() on the victim's PodInPreBind entry, sending the victim back to the backoff queue. + // The victim pod should NOT be deleted from the API server. + // Instead it should be rescheduled onto the small node.
+ err = wait.PollUntilContextTimeout(testCtx.Ctx, 100*time.Millisecond, 10*time.Second, false, func(ctx context.Context) (bool, error) { + // Check if victim is deleted + v, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(ctx, victim.Name, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + return false, fmt.Errorf("victim pod was deleted") + } + return false, err + } + // Check if victim was rescheduled + _, cond := podutil.GetPodCondition(&v.Status, v1.PodScheduled) + if cond != nil && cond.Status == v1.ConditionTrue { + return true, nil + } + return false, nil + }) + if err != nil { + t.Fatalf("Failed waiting for victim validation: %v", err) + } + + // 5. Check that the preemptor and the victim are scheduled on the expected nodes: the victim on the small node and the preemptor on the big node. + v, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(testCtx.Ctx, victim.Name, metav1.GetOptions{}) + if err != nil { + t.Fatalf("Error getting victim: %v", err) + } + if v.Spec.NodeName != "small-node" { + t.Fatalf("Victim should be scheduled on small-node, but was scheduled on %s", v.Spec.NodeName) + } + + p, err := cs.CoreV1().Pods(testCtx.NS.Name).Get(testCtx.Ctx, preemptor.Name, metav1.GetOptions{}) + if err != nil { + t.Fatalf("Error getting preemptor: %v", err) + } + if p.Spec.NodeName != "big-node" { + t.Fatalf("Preemptor should be scheduled on big-node, but was scheduled on %s", p.Spec.NodeName) + } + + // Release the blocked plugin just in case, to ensure a clean teardown. + close(victimBlockingPlugin.released) +}
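
Reviewer note (not part of the patch): below is a minimal, self-contained Go sketch of the binding-cycle flow this change introduces, using only the methods added in this diff (AddPodInPreBind, GetPodInPreBind, RemovePodInPreBind, MarkPrebound, RunPreBindPlugins). The helper name bindWithCancellablePreBind and the local preBindFramework interface are hypothetical, introduced only so the example compiles on its own; the real code lives in bindingCycle in pkg/scheduler/schedule_one.go.

package sketch

import (
	"context"

	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/types"
	fwk "k8s.io/kube-scheduler/framework"
)

// preBindFramework is a hypothetical, minimal interface for this sketch; in the
// scheduler the profile's framework satisfies all of these methods.
type preBindFramework interface {
	RunPreBindPlugins(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status
	AddPodInPreBind(uid types.UID, cancel context.CancelCauseFunc)
	GetPodInPreBind(uid types.UID) fwk.PodInPreBind
	RemovePodInPreBind(uid types.UID)
}

// bindWithCancellablePreBind mirrors the flow added to bindingCycle: register the
// pod so a preemptor can cancel PreBind in-memory, run PreBind, then let
// MarkPrebound race against a possible CancelPod before moving on to Bind.
func bindWithCancellablePreBind(ctx context.Context, f preBindFramework, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status {
	// Attach a cancellable context and expose its cancel func via the
	// pods-in-preBind map for the duration of PreBind.
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)
	defer f.RemovePodInPreBind(pod.UID)
	f.AddPodInPreBind(pod.UID, cancel)

	if status := f.RunPreBindPlugins(ctx, state, pod, nodeName); !status.IsSuccess() {
		// If a preemptor cancelled us, RunPreBindPlugins surfaces context.Cause(ctx).
		return status
	}

	// CancelPod and MarkPrebound are mutually exclusive under the pod's lock:
	// whichever wins decides whether the pod proceeds to Bind or goes back to
	// the queue; a preemption that lands after this point falls back to an
	// API delete instead of an in-memory cancel.
	if p := f.GetPodInPreBind(pod.UID); p != nil && !p.MarkPrebound() {
		return fwk.AsStatus(context.Cause(ctx))
	}

	// ...proceed to Bind as the real bindingCycle does...
	return nil
}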