Merge pull request #136269 from pohly/dra-scheduler-double-allocation-fixes

DRA scheduler: double allocation fixes
This commit is contained in:
Kubernetes Prow Robot 2026-01-26 20:59:50 +05:30 committed by GitHub
commit 53b29a3a2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 307 additions and 36 deletions

View file

@ -85,11 +85,18 @@ func foreachAllocatedDevice(claim *resourceapi.ResourceClaim,
// This is cheaper than repeatedly calling List, making strings unique, and building the set
// each time PreFilter is called.
//
// To simplify detecting concurrent changes, each modification bumps a revision counter,
// similar to ResourceVersion in the apiserver. Get and Capacities include the
// current value in their result. A caller can then compare against the current value
// to determine whether some prior results are still up-to-date, without having to get
// and compare them.
//
// All methods are thread-safe. Get returns a cloned set.
type allocatedDevices struct {
logger klog.Logger
mutex sync.RWMutex
revision int64
ids sets.Set[structured.DeviceID]
shareIDs sets.Set[structured.SharedDeviceID]
capacities structured.ConsumedCapacityCollection
@ -106,18 +113,25 @@ func newAllocatedDevices(logger klog.Logger) *allocatedDevices {
}
}
func (a *allocatedDevices) Get() sets.Set[structured.DeviceID] {
func (a *allocatedDevices) Get() (sets.Set[structured.DeviceID], int64) {
a.mutex.RLock()
defer a.mutex.RUnlock()
return a.ids.Clone()
return a.ids.Clone(), a.revision
}
func (a *allocatedDevices) Capacities() structured.ConsumedCapacityCollection {
func (a *allocatedDevices) Capacities() (structured.ConsumedCapacityCollection, int64) {
a.mutex.RLock()
defer a.mutex.RUnlock()
return a.capacities.Clone()
return a.capacities.Clone(), a.revision
}
// Revision returns the current value of the revision counter.
// The counter is read while holding the read lock.
func (a *allocatedDevices) Revision() int64 {
	a.mutex.RLock()
	current := a.revision
	a.mutex.RUnlock()
	return current
}
func (a *allocatedDevices) handlers() cache.ResourceEventHandler {
@ -200,8 +214,13 @@ func (a *allocatedDevices) addDevices(claim *resourceapi.ResourceClaim) {
},
)
if len(deviceIDs) == 0 && len(shareIDs) == 0 && len(deviceCapacities) == 0 {
return
}
a.mutex.Lock()
defer a.mutex.Unlock()
a.revision++
for _, deviceID := range deviceIDs {
a.ids.Insert(deviceID)
}
@ -241,8 +260,14 @@ func (a *allocatedDevices) removeDevices(claim *resourceapi.ResourceClaim) {
a.logger.V(6).Info("Observed consumed capacity release", "device id", capacity.DeviceID, "consumed capacity", capacity.ConsumedCapacity, "claim", klog.KObj(claim))
deviceCapacities = append(deviceCapacities, capacity)
})
if len(deviceIDs) == 0 && len(shareIDs) == 0 && len(deviceCapacities) == 0 {
return
}
a.mutex.Lock()
defer a.mutex.Unlock()
a.revision++
for _, deviceID := range deviceIDs {
a.ids.Delete(deviceID)
}

View file

@ -18,6 +18,7 @@ package dynamicresources
import (
"context"
"errors"
"fmt"
"sync"
@ -218,10 +219,29 @@ func (c *claimTracker) List() ([]*resourceapi.ResourceClaim, error) {
return result, nil
}
// errClaimTrackerConcurrentModification is returned when ListAllAllocatedDevices
// or GatherAllocatedState needs to be retried.
//
// There is a rare race when a claim is initially in-flight:
// - allocated is created from cache (claim not there)
// - someone removes the claim from the in-flight claims and adds it to the cache
// - we start checking in-flight claims (claim not there anymore)
// => claim ignored
//
// A proper fix would be to rewrite the assume cache, allocatedDevices,
// and the in-flight map so that they are under a single lock. But that's
// a pretty big change and prevents reusing the assume cache. So instead
// we check for changes in the set of allocated devices and keep trying
// until we get an attempt with no concurrent changes.
//
// A claim being first in the cache, then only in-flight cannot happen,
// so we don't need to re-check the in-flight claims.
var errClaimTrackerConcurrentModification = errors.New("conflicting concurrent modification")
func (c *claimTracker) ListAllAllocatedDevices() (sets.Set[structured.DeviceID], error) {
// Start with a fresh set that matches the current known state of the
// world according to the informers.
allocated := c.allocatedDevices.Get()
allocated, revision := c.allocatedDevices.Get()
// Whatever is in flight also has to be checked.
c.inFlightAllocations.Range(func(key, value any) bool {
@ -232,16 +252,26 @@ func (c *claimTracker) ListAllAllocatedDevices() (sets.Set[structured.DeviceID],
}, false, func(structured.SharedDeviceID) {}, func(structured.DeviceConsumedCapacity) {})
return true
})
// There's no reason to return an error in this implementation, but the error might be helpful for other implementations.
return allocated, nil
if revision == c.allocatedDevices.Revision() {
// Our current result is valid, nothing changed in the meantime.
return allocated, nil
}
return nil, errClaimTrackerConcurrentModification
}
func (c *claimTracker) GatherAllocatedState() (*structured.AllocatedState, error) {
// Start with a fresh set that matches the current known state of the
// world according to the informers.
allocated := c.allocatedDevices.Get()
allocated, revision1 := c.allocatedDevices.Get()
allocatedSharedDeviceIDs := sets.New[structured.SharedDeviceID]()
aggregatedCapacity := c.allocatedDevices.Capacities()
aggregatedCapacity, revision2 := c.allocatedDevices.Capacities()
if revision1 != revision2 {
// Already not consistent. Try again.
return nil, errClaimTrackerConcurrentModification
}
enabledConsumableCapacity := utilfeature.DefaultFeatureGate.Enabled(features.DRAConsumableCapacity)
@ -263,12 +293,16 @@ func (c *claimTracker) GatherAllocatedState() (*structured.AllocatedState, error
return true
})
// There's no reason to return an error in this implementation, but the error might be helpful for other implementations.
return &structured.AllocatedState{
AllocatedDevices: allocated,
AllocatedSharedDeviceIDs: allocatedSharedDeviceIDs,
AggregatedCapacity: aggregatedCapacity,
}, nil
if revision1 == c.allocatedDevices.Revision() {
// Our current result is valid, nothing changed in the meantime.
return &structured.AllocatedState{
AllocatedDevices: allocated,
AllocatedSharedDeviceIDs: allocatedSharedDeviceIDs,
AggregatedCapacity: aggregatedCapacity,
}, nil
}
return nil, errClaimTrackerConcurrentModification
}
func (c *claimTracker) AssumeClaimAfterAPICall(claim *resourceapi.ResourceClaim) error {

View file

@ -518,25 +518,45 @@ func (pl *DynamicResources) PreFilter(ctx context.Context, state fwk.CycleState,
// Claims (and thus their devices) are treated as "allocated" if they are in the assume cache
// or currently their allocation is in-flight. This does not change
// during filtering, so we can determine that once.
//
// This might have to be retried in the unlikely case that some concurrent modification made
// the result invalid.
var allocatedState *structured.AllocatedState
if pl.fts.EnableDRAConsumableCapacity {
allocatedState, err = pl.draManager.ResourceClaims().GatherAllocatedState()
if err != nil {
return nil, statusError(logger, err)
}
if allocatedState == nil {
return nil, statusError(logger, errors.New("nil allocated state"))
}
} else {
allocatedDevices, err := pl.draManager.ResourceClaims().ListAllAllocatedDevices()
if err != nil {
return nil, statusError(logger, err)
}
allocatedState = &structured.AllocatedState{
AllocatedDevices: allocatedDevices,
AllocatedSharedDeviceIDs: sets.New[structured.SharedDeviceID](),
AggregatedCapacity: structured.NewConsumedCapacityCollection(),
err = wait.PollUntilContextTimeout(ctx, time.Microsecond, 5*time.Second, true /* immediate */, func(context.Context) (bool, error) {
if pl.fts.EnableDRAConsumableCapacity {
allocatedState, err = pl.draManager.ResourceClaims().GatherAllocatedState()
if err != nil {
if errors.Is(err, errClaimTrackerConcurrentModification) {
logger.V(6).Info("Conflicting modification during GatherAllocatedState, trying again")
return false, nil
}
return false, err
}
if allocatedState == nil {
return false, errors.New("nil allocated state")
}
// Done.
return true, nil
} else {
allocatedDevices, err := pl.draManager.ResourceClaims().ListAllAllocatedDevices()
if err != nil {
if errors.Is(err, errClaimTrackerConcurrentModification) {
logger.V(6).Info("Conflicting modification during ListAllAllocatedDevices, trying again")
return false, nil
}
return false, err
}
allocatedState = &structured.AllocatedState{
AllocatedDevices: allocatedDevices,
AllocatedSharedDeviceIDs: sets.New[structured.SharedDeviceID](),
AggregatedCapacity: structured.NewConsumedCapacityCollection(),
}
// Done.
return true, nil
}
})
if err != nil {
return nil, statusError(logger, fmt.Errorf("gather allocation state: %w", err))
}
slices, err := pl.draManager.ResourceSlices().ListWithDeviceTaintRules()
if err != nil {

View file

@ -124,6 +124,9 @@ type AssumeCache struct {
// Synchronizes updates to all fields below.
rwMutex sync.RWMutex
// cond is used by emitEvents.
cond *sync.Cond
// All registered event handlers.
eventHandlers []cache.ResourceEventHandler
handlerRegistration cache.ResourceEventHandlerRegistration
@ -149,6 +152,9 @@ type AssumeCache struct {
// of events would no longer be guaranteed.
eventQueue buffer.Ring[func()]
// emittingEvents is true while one emitEvents call is actively emitting events.
emittingEvents bool
// describes the object stored
description string
@ -196,6 +202,7 @@ func NewAssumeCache(logger klog.Logger, informer Informer, description, indexNam
indexName: indexName,
eventQueue: *buffer.NewRing[func()](buffer.RingOptions{InitialSize: 0, NormalSize: 4}),
}
c.cond = sync.NewCond(&c.rwMutex)
indexers := cache.Indexers{}
if indexName != "" && indexFunc != nil {
indexers[indexName] = c.objInfoIndexFunc
@ -508,8 +515,31 @@ func (c *AssumeCache) AddEventHandler(handler cache.ResourceEventHandler) cache.
}
// emitEvents delivers all pending events that are in the queue, in the order
// in which they were stored there (FIFO).
// in which they were stored there (FIFO). Only one goroutine at a time is
// delivering events, to ensure correct order.
func (c *AssumeCache) emitEvents() {
c.rwMutex.Lock()
for c.emittingEvents {
// Wait for the active caller of emitEvents to finish.
// When it is done, it may or may not have drained
// the events pushed by our caller.
// We'll check below ourselves.
c.cond.Wait()
}
c.emittingEvents = true
c.rwMutex.Unlock()
defer func() {
c.rwMutex.Lock()
c.emittingEvents = false
// Hand over the baton to one other goroutine, if there is one.
// We don't need to wake up more than one because only one of
// them would be able to grab the "emittingEvents" responsibility.
c.cond.Signal()
c.rwMutex.Unlock()
}()
// When we get here, this instance of emitEvents is the active one.
for {
c.rwMutex.Lock()
deliver, ok := c.eventQueue.ReadOne()

View file

@ -265,6 +265,63 @@ func TestAssume(t *testing.T) {
}
}
// TestAssumeRace simulates this sequence of events:
// - An informer update arrives, the event handler gets invoked and is slow.
// - Assume is called for the same object. It must block until
// the informer-triggered event has been delivered.
//
// The scenario itself is implemented in testAssumeRace, which is
// run through ktesting's SyncTest.
func TestAssumeRace(t *testing.T) { ktesting.Init(t).SyncTest("", testAssumeRace) }
// testAssumeRace checks that Assume for an object does not return while an
// informer-triggered add event for the same object is still being delivered
// to a blocked event handler, and that it returns once the handler is unblocked.
func testAssumeRace(tCtx ktesting.TContext) {
var informer testInformer
testObj := makeObj("pvc1", "1", "")
ac := NewAssumeCache(tCtx.Logger(), &informer, "TestObject", "", nil)
// blockEvent keeps the add handler below blocked until it gets canceled.
// The deferred Cancel ensures that the handler is released when the test returns early.
blockEvent := tCtx.WithCancel()
defer blockEvent.Cancel("test done")
// These flags are written by the event handler goroutine and read by the
// test after tCtx.Wait.
// NOTE(review): presumably tCtx.Wait provides the necessary synchronization
// inside the SyncTest bubble - confirm against ktesting.
eventBlocked := false
eventDone := false
ac.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: func(obj any) {
eventBlocked = true
<-blockEvent.Done()
eventDone = true
},
})
// Here a real client would do a Create. What the assume cache may or may not
// see before Assume is the new object from the informer - let's pretend that
// comes first.
go informer.add(testObj)
// Wait for processing to finish.
tCtx.Wait()
if !eventBlocked {
tCtx.Fatal("Event handler should have been called and wasn't.")
}
// Assume should block until we unblock the event delivery.
assumeDone := false
go func() {
err := ac.Assume(testObj)
tCtx.AssertNoError(err, "Assume failed")
assumeDone = true
}()
// Wait for Assume to be blocked in its implementation.
tCtx.Wait()
if assumeDone {
tCtx.Fatal("Assume should have blocked and didn't.")
}
// Unblock both goroutines.
blockEvent.Cancel("proceed")
tCtx.Wait()
if !eventDone {
tCtx.Fatal("Event should have been delivered and wasn't.")
}
if !assumeDone {
tCtx.Fatal("Assume should have returned and didn't.")
}
}
func TestRestore(t *testing.T) {
tCtx, cache, informer := newTest(t)
var events mockEventHandler

View file

@ -42,6 +42,8 @@ var (
// testDeviceBindingConditions is the entry point for running each integration test that verifies DeviceBindingConditions.
// Some of these tests use device taints, and they assume that DRADeviceTaints is enabled.
func testDeviceBindingConditions(tCtx ktesting.TContext, enabled bool) {
tCtx.Parallel()
tCtx.Run("BasicFlow", func(tCtx ktesting.TContext) { testDeviceBindingConditionsBasicFlow(tCtx, enabled) })
if enabled {
tCtx.Run("FailureTaints", func(tCtx ktesting.TContext) { testDeviceBindingFailureConditionsReschedule(tCtx, true) })

View file

@ -118,8 +118,11 @@ var (
Namespace(namespace).
RequestWithPrioritizedList(st.SubRequest("subreq-1", className, 1)).
Obj()
)
numNodes = 8
const (
numNodes = 8
maxPodsPerNode = 5000 // This should never be the limiting factor, no matter how many tests run in parallel.
)
func TestDRA(t *testing.T) {
@ -154,6 +157,7 @@ func TestDRA(t *testing.T) {
// Number of devices per slice is chosen so that Filter takes a few seconds:
// without a timeout, the test doesn't run too long, but long enough that a short timeout triggers.
tCtx.Run("FilterTimeout", func(tCtx ktesting.TContext) { testFilterTimeout(tCtx, 9) })
tCtx.Run("UsesAllResources", testUsesAllResources)
},
},
"GA": {
@ -176,6 +180,7 @@ func TestDRA(t *testing.T) {
tCtx = tCtx.WithNamespace(namespace)
TestCreateResourceSlices(tCtx, 100)
})
tCtx.Run("UsesAllResources", testUsesAllResources)
},
},
"v1beta1": {
@ -231,6 +236,8 @@ func TestDRA(t *testing.T) {
features.DRAExtendedResource: true,
},
f: func(tCtx ktesting.TContext) {
// These tests must run in parallel as much as possible to keep overall runtime low!
tCtx.Run("AdminAccess", func(tCtx ktesting.TContext) { testAdminAccess(tCtx, true) })
tCtx.Run("Convert", testConvert)
tCtx.Run("ControllerManagerMetrics", testControllerManagerMetrics)
@ -249,6 +256,7 @@ func TestDRA(t *testing.T) {
// in the experimental channel has an improvement that requires a higher number here than
// in the incubating and stable channels.
tCtx.Run("FilterTimeout", func(tCtx ktesting.TContext) { testFilterTimeout(tCtx, 20) })
tCtx.Run("UsesAllResources", testUsesAllResources)
},
},
} {
@ -309,7 +317,7 @@ func createNodes(tCtx ktesting.TContext) {
Capacity: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("100"),
v1.ResourceMemory: resource.MustParse("1000"),
v1.ResourcePods: resource.MustParse("100"),
v1.ResourcePods: *resource.NewScaledQuantity(maxPodsPerNode, 0),
},
Phase: v1.NodeRunning,
Conditions: []v1.NodeCondition{
@ -486,6 +494,8 @@ func testConvert(tCtx ktesting.TContext) {
// when the AdminAccess feature is enabled, it also checks that the field
// is only allowed to be used in namespace with the Resource Admin Access label
func testAdminAccess(tCtx ktesting.TContext, adminAccessEnabled bool) {
tCtx.Parallel()
namespace := createTestNamespace(tCtx, nil)
claim1 := claim.DeepCopy()
claim1.Namespace = namespace
@ -1186,6 +1196,8 @@ func testPublishResourceSlices(tCtx ktesting.TContext, haveLatestAPI bool, disab
//
// When enabled, it tries server-side-apply (SSA) with different clients. This is what DRA drivers should be using.
func testResourceClaimDeviceStatus(tCtx ktesting.TContext, enabled bool) {
tCtx.Parallel()
namespace := createTestNamespace(tCtx, nil)
claim := &resourceapi.ResourceClaim{
@ -1380,7 +1392,8 @@ func testMaxResourceSlice(tCtx ktesting.TContext) {
}
}
// testControllerManagerMetrics tests ResourceClaim metrics
// testControllerManagerMetrics tests ResourceClaim metrics.
// It must run sequentially.
func testControllerManagerMetrics(tCtx ktesting.TContext) {
namespace := createTestNamespace(tCtx, nil)
class, _ := createTestClass(tCtx, namespace)

View file

@ -0,0 +1,90 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dra
import (
"fmt"
"time"
"github.com/onsi/gomega"
v1 "k8s.io/api/core/v1"
resourceapi "k8s.io/api/resource/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/dynamic-resource-allocation/structured"
"k8s.io/klog/v2"
st "k8s.io/kubernetes/pkg/scheduler/testing"
"k8s.io/kubernetes/test/utils/ktesting"
)
// testUsesAllResources publishes one ResourceSlice with numDevicesPerNode
// devices for every node, creates one claim and one pod per device, then
// starts the scheduler. It expects every claim to become allocated with
// exactly one device result and fails if the same device shows up in the
// allocation of two different claims.
func testUsesAllResources(tCtx ktesting.TContext) {
	tCtx.Parallel()

	namespace := createTestNamespace(tCtx, nil)
	nodeList, err := tCtx.Client().CoreV1().Nodes().List(tCtx, metav1.ListOptions{})
	tCtx.ExpectNoError(err, "list nodes")

	// One pod gets scheduled per device.
	const numDevicesPerNode = 100

	class, driverName := createTestClass(tCtx, namespace)
	for _, node := range nodeList.Items {
		// Including the node name makes device names globally unique,
		// which simplifies debugging.
		deviceNames := make([]string, 0, numDevicesPerNode)
		for idx := range numDevicesPerNode {
			deviceNames = append(deviceNames, fmt.Sprintf("%s-device-%03d", node.Name, idx))
		}
		createSlice(tCtx, st.MakeResourceSlice(node.Name, driverName).Devices(deviceNames...).Obj())
	}

	// Create exactly as many claims (and pods using them) as there are devices.
	total := len(nodeList.Items) * numDevicesPerNode
	claims := make([]*resourceapi.ResourceClaim, 0, total)
	pods := make([]*v1.Pod, 0, total)
	for idx := range total {
		tCtx := tCtx.WithStep(fmt.Sprintf("#%04d", idx))
		wanted := st.MakeResourceClaim().
			Name(fmt.Sprintf("claim-%04d", idx)).
			Namespace(namespace).
			Request(class.Name).
			Obj()
		created, err := tCtx.Client().ResourceV1().ResourceClaims(namespace).Create(tCtx, wanted, metav1.CreateOptions{})
		tCtx.ExpectNoError(err, "create claim")
		claims = append(claims, created)
		pods = append(pods, createPod(tCtx, namespace, fmt.Sprintf("-%04d", idx), podWithClaimName, created))
	}

	// The scheduler only gets started after all claims and pods exist.
	startScheduler(tCtx)

	// Eventually, all pods should be scheduled and thus all claims allocated.
	// The timeout scales with the number of pods.
	owners := make(map[structured.DeviceID]*resourceapi.ResourceClaim, len(claims))
	tCtx = tCtx.WithStep("check claim allocation").WithTimeout(time.Duration(len(pods))*5*time.Second, "scheduling timeout for all pods")
	for _, claim := range claims {
		var latest *resourceapi.ResourceClaim
		tCtx.Eventually(func(tCtx ktesting.TContext) (*resourceapi.ResourceClaim, error) {
			got, err := tCtx.Client().ResourceV1().ResourceClaims(claim.Namespace).Get(tCtx, claim.Name, metav1.GetOptions{})
			latest = got
			return got, err
		}).Should(gomega.HaveField("Status.Allocation", gomega.Not(gomega.BeNil())))

		results := latest.Status.Allocation.Devices.Results
		tCtx.Expect(results).To(gomega.HaveLen(1))
		result := results[0]

		// Each device may belong to at most one claim.
		id := structured.MakeDeviceID(result.Driver, result.Pool, result.Device)
		if previous, duplicate := owners[id]; duplicate {
			tCtx.Fatalf("device %s was allocated to claims %s and %s", id, klog.KObj(latest), klog.KObj(previous))
		}
		owners[id] = latest
	}
}