diff --git a/pkg/scheduler/framework/preemption/podgrouppreemption_test.go b/pkg/scheduler/framework/preemption/podgrouppreemption_test.go index e04429b80dc..be8cb58874c 100644 --- a/pkg/scheduler/framework/preemption/podgrouppreemption_test.go +++ b/pkg/scheduler/framework/preemption/podgrouppreemption_test.go @@ -88,8 +88,8 @@ func TestPodGroupEvaluator_SelectVictimsOnDomain(t *testing.T) { st.MakePod().Name("p3").UID("v3").Node("node3").Priority(lowPriority).PodGroupName("pg2").Obj(), }, initPodGroups: []*schedulingapi.PodGroup{ - st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), - st.MakePodGroup().Name("pg2").UID("pg2").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), + st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), + st.MakePodGroup().Name("pg2").UID("pg2").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), }, preemptor: newPodGroupPreemptor( st.MakePodGroup().Name("preemptor-pg").Priority(highPriority).Obj(), @@ -112,7 +112,7 @@ func TestPodGroupEvaluator_SelectVictimsOnDomain(t *testing.T) { st.MakePod().Name("p3").UID("v3").Node("node3").Priority(midPriority).Obj(), }, initPodGroups: []*schedulingapi.PodGroup{ - st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), + st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), }, preemptor: newPodGroupPreemptor( st.MakePodGroup().Name("preemptor-pg").Priority(highPriority).Obj(), @@ -134,7 +134,7 @@ func TestPodGroupEvaluator_SelectVictimsOnDomain(t *testing.T) { st.MakePod().Name("p2").UID("v2").Node("node2").Priority(lowPriority).PodGroupName("pg1").StartTime(metav1.Unix(0, 0)).Obj(), }, initPodGroups: []*schedulingapi.PodGroup{ - st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), + st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), }, preemptor: newPodGroupPreemptor( st.MakePodGroup().Name("preemptor-pg").Priority(highPriority).Obj(), @@ -158,9 +158,9 @@ func TestPodGroupEvaluator_SelectVictimsOnDomain(t *testing.T) { st.MakePod().Name("p5").UID("v5").Node("node5").Priority(highPriority).PodGroupName("pg3").StartTime(metav1.Unix(0, 0)).Obj(), }, initPodGroups: []*schedulingapi.PodGroup{ - st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), - st.MakePodGroup().Name("pg2").UID("pg2").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), - st.MakePodGroup().Name("pg3").UID("pg3").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), + st.MakePodGroup().Name("pg1").UID("pg1").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), + st.MakePodGroup().Name("pg2").UID("pg2").DisruptionMode(schedulingapi.DisruptionModePod).Priority(lowPriority).Obj(), + st.MakePodGroup().Name("pg3").UID("pg3").DisruptionMode(schedulingapi.DisruptionModePod).Priority(highPriority).Obj(), }, preemptor: newPodGroupPreemptor( st.MakePodGroup().Name("preemptor-pg").Priority(highPriority).Obj(), diff --git a/pkg/scheduler/framework/preemption/types.go b/pkg/scheduler/framework/preemption/types.go index f8034e03c19..27116765cb7 100644 --- a/pkg/scheduler/framework/preemption/types.go +++ b/pkg/scheduler/framework/preemption/types.go @@ -83,23 +83,22 @@ type domain struct { allPossibleVictims []*victim } -// isPodGroupPreemptiblePod checks if a pod is a part of a pod group that should -// be treated as a single unit for preemption purposes. -// If the pod is a part of such a pod group, it returns the pod group and true. -// In all other cases, it returns nil and false. -func isPodGroupPreemptiblePod(p *v1.Pod, pgLister schedulinglisters.PodGroupLister) (*schedulingapi.PodGroup, bool) { +// getPodGroup checks if a pod specifies a scheduling group and returns the corresponding PodGroup object if found. +func getPodGroup(p *v1.Pod, pgLister schedulinglisters.PodGroupLister) *schedulingapi.PodGroup { if p.Spec.SchedulingGroup == nil { - return nil, false + return nil } pgName := p.Spec.SchedulingGroup.PodGroupName pg, err := pgLister.PodGroups(p.Namespace).Get(*pgName) if err != nil { - return nil, false + return nil } - if mode := pg.Spec.DisruptionMode; mode == nil || *mode != schedulingapi.DisruptionModePodGroup { - return nil, false - } - return pg, true + return pg +} + +// isDisruptionModePodGroup checks if the PodGroup disruption mode is set to PodGroup. +func isDisruptionModePodGroup(pg *schedulingapi.PodGroup) bool { + return pg != nil && pg.Spec.DisruptionMode != nil && *pg.Spec.DisruptionMode == schedulingapi.DisruptionModePodGroup } // newDomainForWorkloadPreemption creates a new domain for workload preemption. @@ -107,6 +106,8 @@ func isPodGroupPreemptiblePod(p *v1.Pod, pgLister schedulinglisters.PodGroupList // on the pods and their scheduling groups. // Pods that are part of a pod group with the PodGroup disruption mode are grouped // together into a single victim. Otherwise, they are treated as individual victims. +// In both cases, the priority of the victim is determined by the PodGroup priority +// if the pod belongs to a PodGroup. func newDomainForWorkloadPreemption(nodes []fwk.NodeInfo, pgLister schedulinglisters.PodGroupLister, name string) *domain { victimMap := map[types.UID]*victim{} for _, node := range nodes { @@ -114,11 +115,15 @@ func newDomainForWorkloadPreemption(nodes []fwk.NodeInfo, pgLister schedulinglis // TODO: Calling the lister here is not ideal given we do this // for every pod in the cluster. Instead, we should be getting // this information from the snapshot. - pg, ok := isPodGroupPreemptiblePod(p.GetPod(), pgLister) - if !ok { + pg := getPodGroup(p.GetPod(), pgLister) + if pg == nil { victimMap[p.GetPod().UID] = newVictim([]fwk.PodInfo{p}, corev1helpers.PodPriority(p.GetPod()), []fwk.NodeInfo{node}) continue } + if !isDisruptionModePodGroup(pg) { + victimMap[p.GetPod().UID] = newVictim([]fwk.PodInfo{p}, util.PodGroupPriority(pg), []fwk.NodeInfo{node}) + continue + } victim, ok := victimMap[pg.UID] if ok { victim.pods = append(victim.pods, p) diff --git a/pkg/scheduler/framework/preemption/types_test.go b/pkg/scheduler/framework/preemption/types_test.go index c53349f4962..24a7bb3fdf4 100644 --- a/pkg/scheduler/framework/preemption/types_test.go +++ b/pkg/scheduler/framework/preemption/types_test.go @@ -29,65 +29,77 @@ import ( st "k8s.io/kubernetes/pkg/scheduler/testing" ) -func TestIsPodGroupPreemptiblePod(t *testing.T) { +func TestGetPodGroup(t *testing.T) { tests := []struct { name string pod *v1.Pod podGroups map[string]*schedulingapi.PodGroup wantPodGroup *schedulingapi.PodGroup - wantOk bool }{ { name: "pod without scheduling group", pod: st.MakePod().Name("p1").Namespace("default").Obj(), wantPodGroup: nil, - wantOk: false, }, { name: "pod group not found", pod: st.MakePod().Name("p1").Namespace("default").PodGroupName("pg1").Obj(), podGroups: map[string]*schedulingapi.PodGroup{}, wantPodGroup: nil, - wantOk: false, }, { - name: "pod group with nil disruption mode", + name: "pod group found", pod: st.MakePod().Name("p1").Namespace("default").PodGroupName("pg1").Obj(), podGroups: map[string]*schedulingapi.PodGroup{ "pg1": st.MakePodGroup().Name("pg1").Namespace("default").Obj(), }, - wantPodGroup: nil, - wantOk: false, - }, - { - name: "pod group with DisruptionModePod", - pod: st.MakePod().Name("p1").Namespace("default").PodGroupName("pg1").Obj(), - podGroups: map[string]*schedulingapi.PodGroup{ - "pg1": st.MakePodGroup().Name("pg1").Namespace("default").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), - }, - wantPodGroup: nil, - wantOk: false, - }, - { - name: "pod group with DisruptionModePodGroup", - pod: st.MakePod().Name("p1").Namespace("default").PodGroupName("pg1").Obj(), - podGroups: map[string]*schedulingapi.PodGroup{ - "pg1": st.MakePodGroup().Name("pg1").Namespace("default").DisruptionMode(schedulingapi.DisruptionModePodGroup).Obj(), - }, - wantPodGroup: st.MakePodGroup().Name("pg1").Namespace("default").DisruptionMode(schedulingapi.DisruptionModePodGroup).Obj(), - wantOk: true, + wantPodGroup: st.MakePodGroup().Name("pg1").Namespace("default").Obj(), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pgLister := &mockPodGroupLister{podGroups: tt.podGroups} - podGroup, ok := isPodGroupPreemptiblePod(tt.pod, pgLister) - if ok != tt.wantOk { - t.Errorf("isPodGroupPreemptiblePod() gotOk = %v, want %v", ok, tt.wantOk) - } + podGroup := getPodGroup(tt.pod, pgLister) if diff := cmp.Diff(tt.wantPodGroup, podGroup); diff != "" { - t.Errorf("isPodGroupPreemptiblePod() gotPodGroup mismatch (-want +got):\n%s", diff) + t.Errorf("getPodGroup() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestIsDisruptionModePodGroup(t *testing.T) { + tests := []struct { + name string + pg *schedulingapi.PodGroup + wantModePG bool + }{ + { + name: "nil pod group", + pg: nil, + wantModePG: false, + }, + { + name: "pod group with nil disruption mode", + pg: st.MakePodGroup().Name("pg1").Namespace("default").Obj(), + wantModePG: false, + }, + { + name: "pod group with DisruptionModePod", + pg: st.MakePodGroup().Name("pg1").Namespace("default").DisruptionMode(schedulingapi.DisruptionModePod).Obj(), + wantModePG: false, + }, + { + name: "pod group with DisruptionModePodGroup", + pg: st.MakePodGroup().Name("pg1").Namespace("default").DisruptionMode(schedulingapi.DisruptionModePodGroup).Obj(), + wantModePG: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotModePG := isDisruptionModePodGroup(tt.pg); gotModePG != tt.wantModePG { + t.Errorf("isDisruptionModePodGroup() = %v, want %v", gotModePG, tt.wantModePG) } }) } @@ -168,8 +180,8 @@ func TestNewDomainForWorkloadPreemption(t *testing.T) { }, domainName: "test-domain", wantVictims: []expectedVictim{ - {pods: sets.New("p1"), affectedNodes: sets.New("node1"), priority: 10}, - {pods: sets.New("p2"), affectedNodes: sets.New("node2"), priority: 20}, + {pods: sets.New("p1"), affectedNodes: sets.New("node1"), priority: 50}, + {pods: sets.New("p2"), affectedNodes: sets.New("node2"), priority: 50}, }, }, { @@ -191,7 +203,7 @@ func TestNewDomainForWorkloadPreemption(t *testing.T) { domainName: "test-domain", wantVictims: []expectedVictim{ {pods: sets.New("p1", "p2"), affectedNodes: sets.New("node1", "node2"), priority: 50}, - {pods: sets.New("p3"), affectedNodes: sets.New("node1"), priority: 20}, + {pods: sets.New("p3"), affectedNodes: sets.New("node1"), priority: 60}, {pods: sets.New("p4"), affectedNodes: sets.New("node2"), priority: 30}, }, },