controller/tainteviction: Improve goroutine mgmt

Make sure all goroutines started by Run are terminated when Run returns.
This commit is contained in:
Ondra Kupka 2025-10-28 16:39:28 +01:00
parent d2a443db75
commit 9d4ff6ecf2
4 changed files with 156 additions and 66 deletions

View file

@ -278,22 +278,24 @@ func New(ctx context.Context, c clientset.Interface, podInformer corev1informers
// Run starts the controller, which runs in a loop until `ctx` is cancelled.
func (tc *Controller) Run(ctx context.Context) {
defer utilruntime.HandleCrash()
logger := klog.FromContext(ctx)
logger.Info("Starting", "controller", tc.name)
defer logger.Info("Shutting down controller", "controller", tc.name)
// Start events processing pipeline.
tc.broadcaster.StartStructuredLogging(3)
if tc.client != nil {
logger.Info("Sending events to api server")
tc.broadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: tc.client.CoreV1().Events("")})
} else {
logger.Error(nil, "kubeClient is nil", "controller", tc.name)
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
}
tc.broadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: tc.client.CoreV1().Events("")})
logger.Info("Sending events to API server")
defer tc.broadcaster.Shutdown()
defer tc.nodeUpdateQueue.ShutDown()
defer tc.podUpdateQueue.ShutDown()
var wg sync.WaitGroup
defer func() {
logger.Info("Shutting down controller", "controller", tc.name)
tc.nodeUpdateQueue.ShutDown()
tc.podUpdateQueue.ShutDown()
tc.taintEvictionQueue.CancelAndWait()
wg.Wait()
}()
// wait for the cache to be synced
if !cache.WaitForNamedCacheSyncWithContext(ctx, tc.podListerSynced, tc.nodeListerSynced) {
@ -305,9 +307,8 @@ func (tc *Controller) Run(ctx context.Context) {
tc.podUpdateChannels = append(tc.podUpdateChannels, make(chan podUpdateItem, podUpdateChannelSize))
}
// Functions that are responsible for taking work items out of the workqueues and putting them
// into channels.
go func(stopCh <-chan struct{}) {
// Functions that are responsible for taking work items out of the workqueues and putting them into channels.
wg.Go(func() {
for {
nodeUpdate, shutdown := tc.nodeUpdateQueue.Get()
if shutdown {
@ -315,16 +316,16 @@ func (tc *Controller) Run(ctx context.Context) {
}
hash := hash(nodeUpdate.nodeName, UpdateWorkerSize)
select {
case <-stopCh:
case <-ctx.Done():
tc.nodeUpdateQueue.Done(nodeUpdate)
return
case tc.nodeUpdateChannels[hash] <- nodeUpdate:
// tc.nodeUpdateQueue.Done is called by the nodeUpdateChannels worker
}
}
}(ctx.Done())
})
go func(stopCh <-chan struct{}) {
wg.Go(func() {
for {
podUpdate, shutdown := tc.podUpdateQueue.Get()
if shutdown {
@ -336,33 +337,31 @@ func (tc *Controller) Run(ctx context.Context) {
// It's possible that even without this assumption this code is still correct.
hash := hash(podUpdate.nodeName, UpdateWorkerSize)
select {
case <-stopCh:
case <-ctx.Done():
tc.podUpdateQueue.Done(podUpdate)
return
case tc.podUpdateChannels[hash] <- podUpdate:
// tc.podUpdateQueue.Done is called by the podUpdateChannels worker
}
}
}(ctx.Done())
})
wg := sync.WaitGroup{}
wg.Add(UpdateWorkerSize)
for i := 0; i < UpdateWorkerSize; i++ {
go tc.worker(ctx, i, wg.Done, ctx.Done())
wg.Go(func() {
tc.worker(ctx, i)
})
}
wg.Wait()
<-ctx.Done()
}
func (tc *Controller) worker(ctx context.Context, worker int, done func(), stopCh <-chan struct{}) {
defer done()
func (tc *Controller) worker(ctx context.Context, worker int) {
// When processing events we want to prioritize Node updates over Pod updates,
// as NodeUpdates that interest the controller should be handled as soon as possible -
// we don't want user (or system) to wait until PodUpdate queue is drained before it can
// start evicting Pods from tainted Nodes.
for {
select {
case <-stopCh:
case <-ctx.Done():
return
case nodeUpdate := <-tc.nodeUpdateChannels[worker]:
tc.handleNodeUpdate(ctx, nodeUpdate)

View file

@ -21,6 +21,7 @@ import (
"fmt"
goruntime "runtime"
"sort"
"sync"
"testing"
"time"
@ -192,31 +193,40 @@ func TestCreatePod(t *testing.T) {
for _, item := range testCases {
t.Run(item.description, func(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: []corev1.Pod{*item.pod}})
controller, podIndexer, _ := setupNewController(ctx, fakeClientset)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
controller.taintedNodes = item.taintedNodes
wg.Go(func() {
controller.Run(ctx)
})
podIndexer.Add(item.pod)
controller.PodUpdated(nil, item.pod)
verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete)
cancel()
})
}
}
func TestDeletePod(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fakeClientset := fake.NewSimpleClientset()
controller, _, _ := setupNewController(ctx, fakeClientset)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
controller.taintedNodes = map[string][]corev1.Taint{
"node1": {createNoExecuteTaint(1)},
}
@ -286,12 +296,20 @@ func TestUpdatePod(t *testing.T) {
// TODO: remove skip once the flaking test has been fixed.
t.Skip("Skip flaking test on Windows.")
}
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: []corev1.Pod{*item.prevPod}})
controller, podIndexer, _ := setupNewController(context.TODO(), fakeClientset)
controller.recorder = testutil.NewFakeRecorder()
controller.taintedNodes = item.taintedNodes
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
podIndexer.Add(item.prevPod)
controller.PodUpdated(nil, item.prevPod)
@ -311,7 +329,6 @@ func TestUpdatePod(t *testing.T) {
controller.PodUpdated(item.prevPod, item.newPod)
verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete)
cancel()
})
}
}
@ -354,29 +371,47 @@ func TestCreateNode(t *testing.T) {
}
for _, item := range testCases {
ctx, cancel := context.WithCancel(context.Background())
fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: item.pods})
controller, _, nodeIndexer := setupNewController(ctx, fakeClientset)
nodeIndexer.Add(item.node)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
controller.NodeUpdated(nil, item.node)
t.Run(item.description, func(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete)
fakeClientset := fake.NewClientset(&corev1.PodList{Items: item.pods})
controller, _, nodeIndexer := setupNewController(ctx, fakeClientset)
if err := nodeIndexer.Add(item.node); err != nil {
t.Fatalf("Failed to add node %q: %v", item.node.GetName(), err)
}
controller.recorder = testutil.NewFakeRecorder()
cancel()
wg.Go(func() {
controller.Run(ctx)
})
controller.NodeUpdated(nil, item.node)
verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete)
})
}
}
func TestDeleteNode(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fakeClientset := fake.NewSimpleClientset()
controller, _, _ := setupNewController(ctx, fakeClientset)
controller.recorder = testutil.NewFakeRecorder()
controller.taintedNodes = map[string][]corev1.Taint{
"node1": {createNoExecuteTaint(1)},
}
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
controller.NodeUpdated(testutil.NewNode("node1"), nil)
// await until controller.taintedNodes is empty
@ -389,7 +424,6 @@ func TestDeleteNode(t *testing.T) {
if err != nil {
t.Errorf("Failed to await for processing node deleted: %q", err)
}
cancel()
}
func TestUpdateNode(t *testing.T) {
@ -475,6 +509,8 @@ func TestUpdateNode(t *testing.T) {
for _, item := range testCases {
t.Run(item.description, func(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -482,7 +518,11 @@ func TestUpdateNode(t *testing.T) {
controller, _, nodeIndexer := setupNewController(ctx, fakeClientset)
nodeIndexer.Add(item.newNode)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
controller.NodeUpdated(item.oldNode, item.newNode)
if item.additionalSleep > 0 {
@ -514,11 +554,18 @@ func TestUpdateNodeWithMultipleTaints(t *testing.T) {
singleTaintedNode := testutil.NewNode("node1")
singleTaintedNode.Spec.Taints = []corev1.Taint{taint1}
ctx, cancel := context.WithCancel(context.TODO())
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fakeClientset := fake.NewSimpleClientset(pod)
controller, _, nodeIndexer := setupNewController(ctx, fakeClientset)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
// no taint
nodeIndexer.Add(untaintedNode)
@ -558,7 +605,6 @@ func TestUpdateNodeWithMultipleTaints(t *testing.T) {
t.Error("Unexpected deletion")
}
}
cancel()
}
func TestUpdateNodeWithMultiplePods(t *testing.T) {
@ -601,6 +647,9 @@ func TestUpdateNodeWithMultiplePods(t *testing.T) {
for _, item := range testCases {
t.Run(item.description, func(t *testing.T) {
t.Logf("Starting testcase %q", item.description)
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -609,7 +658,11 @@ func TestUpdateNodeWithMultiplePods(t *testing.T) {
controller, _, nodeIndexer := setupNewController(ctx, fakeClientset)
nodeIndexer.Add(item.newNode)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
controller.NodeUpdated(item.oldNode, item.newNode)
startedAt := time.Now()
@ -809,6 +862,8 @@ func TestEventualConsistency(t *testing.T) {
for _, item := range testCases {
t.Run(item.description, func(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -816,7 +871,10 @@ func TestEventualConsistency(t *testing.T) {
controller, podIndexer, nodeIndexer := setupNewController(ctx, fakeClientset)
nodeIndexer.Add(item.newNode)
controller.recorder = testutil.NewFakeRecorder()
go controller.Run(ctx)
wg.Go(func() {
controller.Run(ctx)
})
if item.prevPod != nil {
podIndexer.Add(item.prevPod)

View file

@ -19,6 +19,7 @@ package tainteviction
import (
"context"
"sync"
"sync/atomic"
"time"
"k8s.io/apimachinery/pkg/types"
@ -54,35 +55,54 @@ type TimedWorker struct {
CreatedAt time.Time
FireAt time.Time
Timer clock.Timer
cancelled atomic.Bool
}
// createWorker creates a TimedWorker that will execute `f` not earlier than `fireAt`.
// Returns nil if the work was started immediately and doesn't need a timer.
func createWorker(ctx context.Context, args *WorkArgs, createdAt time.Time, fireAt time.Time, f func(ctx context.Context, fireAt time.Time, args *WorkArgs) error, clock clock.WithDelayedExecution) *TimedWorker {
func createWorker(ctx context.Context, wg *sync.WaitGroup, args *WorkArgs, createdAt time.Time, fireAt time.Time, f func(ctx context.Context, fireAt time.Time, args *WorkArgs) error, clock clock.WithDelayedExecution) *TimedWorker {
delay := fireAt.Sub(createdAt)
logger := klog.FromContext(ctx)
fWithErrorLogging := func() {
err := f(ctx, fireAt, args)
if err != nil {
logger.Error(err, "TaintEvictionController: timed worker failed")
}
}
if delay <= 0 {
go fWithErrorLogging()
return nil
}
timer := clock.AfterFunc(delay, fWithErrorLogging)
return &TimedWorker{
worker := TimedWorker{
WorkItem: args,
CreatedAt: createdAt,
FireAt: fireAt,
Timer: timer,
}
// This dance with the cancelled flag is here so that we can be sure that once TimedWorker.Cancel returns,
// we either never get to any processing, or the thread is already registered with the WaitGroup.
// Otherwise, the following sequence can happen:
// 1. TimedWorker.Timer fires and gets to go wrapper().
// 2. TimedWorker.Cancel is called to stop the timer, but a goroutine is already running.
// It's started, but wg.Go hasn't been called yet to register the goroutine.
// 3. We call wg.Wait, which unblocks, because the inner wg.Go hasn't been called yet.
// This causes wg.Wait to unblock after TimedWorker.Cancel is called and still start a goroutine.
// So in our case we can still get to starting an unregistered goroutine, but it will exit immediately.
wrapper := func() {
wg.Go(func() {
if worker.cancelled.Load() {
return
}
if err := f(ctx, fireAt, args); err != nil {
logger.Error(err, "TaintEvictionController: timed worker failed")
}
})
}
if delay <= 0 {
wrapper()
return nil
}
worker.Timer = clock.AfterFunc(delay, wrapper)
return &worker
}
// Cancel prevents the function scheduled on the TimedWorker from being executed.
// Calling Cancel on a nil *TimedWorker is a no-op.
func (w *TimedWorker) Cancel() {
	if w == nil {
		return
	}
	// Raise the cancelled flag before stopping the timer: a timer callback that
	// has already fired checks this flag first and returns without doing any
	// work, so on return from Cancel the worker is either already running or
	// will never start.
	w.cancelled.Store(true)
	w.Timer.Stop()
}
@ -93,6 +113,7 @@ type TimedWorkerQueue struct {
// map of workers keyed by string returned by 'KeyFromWorkArgs' from the given worker.
// Entries may be nil if the work didn't need a timer and is already running.
workers map[string]*TimedWorker
workerWG sync.WaitGroup
workFunc func(ctx context.Context, fireAt time.Time, args *WorkArgs) error
clock clock.WithDelayedExecution
}
@ -134,7 +155,7 @@ func (q *TimedWorkerQueue) AddWork(ctx context.Context, args *WorkArgs, createdA
return
}
logger.V(4).Info("Adding TimedWorkerQueue item and to be fired at firedTime", "item", key, "createTime", createdAt, "firedTime", fireAt)
worker := createWorker(ctx, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock)
worker := createWorker(ctx, &q.workerWG, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock)
q.workers[key] = worker
}
@ -159,7 +180,7 @@ func (q *TimedWorkerQueue) UpdateWork(ctx context.Context, args *WorkArgs, creat
worker.Cancel()
}
logger.V(4).Info("Adding TimedWorkerQueue item and to be fired at firedTime", "item", key, "createTime", createdAt, "firedTime", fireAt)
worker := createWorker(ctx, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock)
worker := createWorker(ctx, &q.workerWG, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock)
q.workers[key] = worker
}
@ -189,3 +210,15 @@ func (q *TimedWorkerQueue) GetWorkerUnsafe(key string) *TimedWorker {
defer q.Unlock()
return q.workers[key]
}
// CancelAndWait cancels every worker currently tracked by the queue, resets the
// worker map, and then blocks until all goroutines registered with the queue's
// WaitGroup have finished.
func (q *TimedWorkerQueue) CancelAndWait() {
	// cancelAll does its work while holding the queue lock and releases the
	// lock before returning.
	cancelAll := func() {
		q.Lock()
		defer q.Unlock()

		// Entries may be nil; Cancel tolerates a nil receiver.
		for key := range q.workers {
			q.workers[key].Cancel()
		}
		q.workers = map[string]*TimedWorker{}
	}

	// Wait must only run once the lock has been released, otherwise this hangs
	// (presumably running workers need the queue lock to finish; confirm
	// against the wrapped worker function). It is deferred so that it still
	// runs if cancelAll panics.
	defer q.workerWG.Wait()
	cancelAll()
}

View file

@ -116,7 +116,7 @@ func TestCancel(t *testing.T) {
}
}
func TestCancelAndReadd(t *testing.T) {
func TestCancelAndRead(t *testing.T) {
logger, ctx := ktesting.NewTestContext(t)
testVal := int32(0)
wg := sync.WaitGroup{}