diff --git a/pkg/controller/tainteviction/taint_eviction.go b/pkg/controller/tainteviction/taint_eviction.go index c4a24e14b2e..f2f098900e9 100644 --- a/pkg/controller/tainteviction/taint_eviction.go +++ b/pkg/controller/tainteviction/taint_eviction.go @@ -278,22 +278,24 @@ func New(ctx context.Context, c clientset.Interface, podInformer corev1informers // Run starts the controller which will run in loop until `stopCh` is closed. func (tc *Controller) Run(ctx context.Context) { defer utilruntime.HandleCrash() + logger := klog.FromContext(ctx) logger.Info("Starting", "controller", tc.name) - defer logger.Info("Shutting down controller", "controller", tc.name) // Start events processing pipeline. tc.broadcaster.StartStructuredLogging(3) - if tc.client != nil { - logger.Info("Sending events to api server") - tc.broadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: tc.client.CoreV1().Events("")}) - } else { - logger.Error(nil, "kubeClient is nil", "controller", tc.name) - klog.FlushAndExit(klog.ExitFlushTimeout, 1) - } + tc.broadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: tc.client.CoreV1().Events("")}) + logger.Info("Sending events to API server") defer tc.broadcaster.Shutdown() - defer tc.nodeUpdateQueue.ShutDown() - defer tc.podUpdateQueue.ShutDown() + + var wg sync.WaitGroup + defer func() { + logger.Info("Shutting down controller", "controller", tc.name) + tc.nodeUpdateQueue.ShutDown() + tc.podUpdateQueue.ShutDown() + tc.taintEvictionQueue.CancelAndWait() + wg.Wait() + }() // wait for the cache to be synced if !cache.WaitForNamedCacheSyncWithContext(ctx, tc.podListerSynced, tc.nodeListerSynced) { @@ -305,9 +307,8 @@ func (tc *Controller) Run(ctx context.Context) { tc.podUpdateChannels = append(tc.podUpdateChannels, make(chan podUpdateItem, podUpdateChannelSize)) } - // Functions that are responsible for taking work items out of the workqueues and putting them - // into channels. - go func(stopCh <-chan struct{}) { + // Functions that are responsible for taking work items out of the workqueues and putting them into channels. + wg.Go(func() { for { nodeUpdate, shutdown := tc.nodeUpdateQueue.Get() if shutdown { @@ -315,16 +316,16 @@ func (tc *Controller) Run(ctx context.Context) { } hash := hash(nodeUpdate.nodeName, UpdateWorkerSize) select { - case <-stopCh: + case <-ctx.Done(): tc.nodeUpdateQueue.Done(nodeUpdate) return case tc.nodeUpdateChannels[hash] <- nodeUpdate: // tc.nodeUpdateQueue.Done is called by the nodeUpdateChannels worker } } - }(ctx.Done()) + }) - go func(stopCh <-chan struct{}) { + wg.Go(func() { for { podUpdate, shutdown := tc.podUpdateQueue.Get() if shutdown { @@ -336,33 +337,31 @@ func (tc *Controller) Run(ctx context.Context) { // It's possible that even without this assumption this code is still correct. hash := hash(podUpdate.nodeName, UpdateWorkerSize) select { - case <-stopCh: + case <-ctx.Done(): tc.podUpdateQueue.Done(podUpdate) return case tc.podUpdateChannels[hash] <- podUpdate: // tc.podUpdateQueue.Done is called by the podUpdateChannels worker } } - }(ctx.Done()) + }) - wg := sync.WaitGroup{} - wg.Add(UpdateWorkerSize) for i := 0; i < UpdateWorkerSize; i++ { - go tc.worker(ctx, i, wg.Done, ctx.Done()) + wg.Go(func() { + tc.worker(ctx, i) + }) } - wg.Wait() + <-ctx.Done() } -func (tc *Controller) worker(ctx context.Context, worker int, done func(), stopCh <-chan struct{}) { - defer done() - +func (tc *Controller) worker(ctx context.Context, worker int) { // When processing events we want to prioritize Node updates over Pod updates, // as NodeUpdates that interest the controller should be handled as soon as possible - // we don't want user (or system) to wait until PodUpdate queue is drained before it can // start evicting Pods from tainted Nodes. for { select { - case <-stopCh: + case <-ctx.Done(): return case nodeUpdate := <-tc.nodeUpdateChannels[worker]: tc.handleNodeUpdate(ctx, nodeUpdate) diff --git a/pkg/controller/tainteviction/taint_eviction_test.go b/pkg/controller/tainteviction/taint_eviction_test.go index 6c933d5f23d..f10c5a630f7 100644 --- a/pkg/controller/tainteviction/taint_eviction_test.go +++ b/pkg/controller/tainteviction/taint_eviction_test.go @@ -21,6 +21,7 @@ import ( "fmt" goruntime "runtime" "sort" + "sync" "testing" "time" @@ -192,31 +193,40 @@ func TestCreatePod(t *testing.T) { for _, item := range testCases { t.Run(item.description, func(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: []corev1.Pod{*item.pod}}) controller, podIndexer, _ := setupNewController(ctx, fakeClientset) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) controller.taintedNodes = item.taintedNodes + wg.Go(func() { + controller.Run(ctx) + }) + podIndexer.Add(item.pod) controller.PodUpdated(nil, item.pod) verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete) - - cancel() }) } } func TestDeletePod(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() fakeClientset := fake.NewSimpleClientset() controller, _, _ := setupNewController(ctx, fakeClientset) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) + wg.Go(func() { + controller.Run(ctx) + }) controller.taintedNodes = map[string][]corev1.Taint{ "node1": {createNoExecuteTaint(1)}, } @@ -286,12 +296,20 @@ func TestUpdatePod(t *testing.T) { // TODO: remove skip once the flaking test has been fixed. t.Skip("Skip flaking test on Windows.") } + + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: []corev1.Pod{*item.prevPod}}) controller, podIndexer, _ := setupNewController(context.TODO(), fakeClientset) controller.recorder = testutil.NewFakeRecorder() controller.taintedNodes = item.taintedNodes - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) podIndexer.Add(item.prevPod) controller.PodUpdated(nil, item.prevPod) @@ -311,7 +329,6 @@ func TestUpdatePod(t *testing.T) { controller.PodUpdated(item.prevPod, item.newPod) verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete) - cancel() }) } } @@ -354,29 +371,47 @@ func TestCreateNode(t *testing.T) { } for _, item := range testCases { - ctx, cancel := context.WithCancel(context.Background()) - fakeClientset := fake.NewSimpleClientset(&corev1.PodList{Items: item.pods}) - controller, _, nodeIndexer := setupNewController(ctx, fakeClientset) - nodeIndexer.Add(item.node) - controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) - controller.NodeUpdated(nil, item.node) + t.Run(item.description, func(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete) + fakeClientset := fake.NewClientset(&corev1.PodList{Items: item.pods}) + controller, _, nodeIndexer := setupNewController(ctx, fakeClientset) + if err := nodeIndexer.Add(item.node); err != nil { + t.Fatalf("Failed to add node %q: %v", item.node.GetName(), err) + } + controller.recorder = testutil.NewFakeRecorder() - cancel() + wg.Go(func() { + controller.Run(ctx) + }) + + controller.NodeUpdated(nil, item.node) + + verifyPodActions(t, item.description, fakeClientset, item.expectPatch, item.expectDelete) + }) } } func TestDeleteNode(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fakeClientset := fake.NewSimpleClientset() controller, _, _ := setupNewController(ctx, fakeClientset) controller.recorder = testutil.NewFakeRecorder() controller.taintedNodes = map[string][]corev1.Taint{ "node1": {createNoExecuteTaint(1)}, } - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) + controller.NodeUpdated(testutil.NewNode("node1"), nil) // await until controller.taintedNodes is empty @@ -389,7 +424,6 @@ func TestDeleteNode(t *testing.T) { if err != nil { t.Errorf("Failed to await for processing node deleted: %q", err) } - cancel() } func TestUpdateNode(t *testing.T) { @@ -475,6 +509,8 @@ func TestUpdateNode(t *testing.T) { for _, item := range testCases { t.Run(item.description, func(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -482,7 +518,11 @@ func TestUpdateNode(t *testing.T) { controller, _, nodeIndexer := setupNewController(ctx, fakeClientset) nodeIndexer.Add(item.newNode) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) + controller.NodeUpdated(item.oldNode, item.newNode) if item.additionalSleep > 0 { @@ -514,11 +554,18 @@ func TestUpdateNodeWithMultipleTaints(t *testing.T) { singleTaintedNode := testutil.NewNode("node1") singleTaintedNode.Spec.Taints = []corev1.Taint{taint1} - ctx, cancel := context.WithCancel(context.TODO()) + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fakeClientset := fake.NewSimpleClientset(pod) controller, _, nodeIndexer := setupNewController(ctx, fakeClientset) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) // no taint nodeIndexer.Add(untaintedNode) @@ -558,7 +605,6 @@ func TestUpdateNodeWithMultipleTaints(t *testing.T) { t.Error("Unexpected deletion") } } - cancel() } func TestUpdateNodeWithMultiplePods(t *testing.T) { @@ -601,6 +647,9 @@ func TestUpdateNodeWithMultiplePods(t *testing.T) { for _, item := range testCases { t.Run(item.description, func(t *testing.T) { t.Logf("Starting testcase %q", item.description) + + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -609,7 +658,11 @@ func TestUpdateNodeWithMultiplePods(t *testing.T) { controller, _, nodeIndexer := setupNewController(ctx, fakeClientset) nodeIndexer.Add(item.newNode) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) + controller.NodeUpdated(item.oldNode, item.newNode) startedAt := time.Now() @@ -809,6 +862,8 @@ func TestEventualConsistency(t *testing.T) { for _, item := range testCases { t.Run(item.description, func(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -816,7 +871,10 @@ func TestEventualConsistency(t *testing.T) { controller, podIndexer, nodeIndexer := setupNewController(ctx, fakeClientset) nodeIndexer.Add(item.newNode) controller.recorder = testutil.NewFakeRecorder() - go controller.Run(ctx) + + wg.Go(func() { + controller.Run(ctx) + }) if item.prevPod != nil { podIndexer.Add(item.prevPod) diff --git a/pkg/controller/tainteviction/timed_workers.go b/pkg/controller/tainteviction/timed_workers.go index 3bb2e48d0a7..bd9b4a326ca 100644 --- a/pkg/controller/tainteviction/timed_workers.go +++ b/pkg/controller/tainteviction/timed_workers.go @@ -19,6 +19,7 @@ package tainteviction import ( "context" "sync" + "sync/atomic" "time" "k8s.io/apimachinery/pkg/types" @@ -54,35 +55,54 @@ type TimedWorker struct { CreatedAt time.Time FireAt time.Time Timer clock.Timer + cancelled atomic.Bool } // createWorker creates a TimedWorker that will execute `f` not earlier than `fireAt`. // Returns nil if the work was started immediately and doesn't need a timer. -func createWorker(ctx context.Context, args *WorkArgs, createdAt time.Time, fireAt time.Time, f func(ctx context.Context, fireAt time.Time, args *WorkArgs) error, clock clock.WithDelayedExecution) *TimedWorker { +func createWorker(ctx context.Context, wg *sync.WaitGroup, args *WorkArgs, createdAt time.Time, fireAt time.Time, f func(ctx context.Context, fireAt time.Time, args *WorkArgs) error, clock clock.WithDelayedExecution) *TimedWorker { delay := fireAt.Sub(createdAt) logger := klog.FromContext(ctx) - fWithErrorLogging := func() { - err := f(ctx, fireAt, args) - if err != nil { - logger.Error(err, "TaintEvictionController: timed worker failed") - } - } - if delay <= 0 { - go fWithErrorLogging() - return nil - } - timer := clock.AfterFunc(delay, fWithErrorLogging) - return &TimedWorker{ + + worker := TimedWorker{ WorkItem: args, CreatedAt: createdAt, FireAt: fireAt, - Timer: timer, } + + // This dance with the cancelled flag is here so that we can be sure that once TimedWorker.Cancel returns, + // we either never get to any processing, or the thread is already registered with the WaitGroup. + // Otherwise, the following sequence can happen: + // 1. TimedWorker.Timer fires and gets to go wrapper(). + // 2. TimedWorker.Cancel is called to stop the timer, but a goroutine is already running. + // It's started, but wg.Go hasn't been called yet to register the goroutine. + // 3. We call wg.Wait, which unblocks, because the inner wg.Go hasn't been called yet. + // This causes wg.Wait to unblock after TimedWorker.Cancel is called and still start a goroutine. + // So in our case we can still get to starting an unregistered goroutine, but it will exit immediately. + wrapper := func() { + wg.Go(func() { + if worker.cancelled.Load() { + return + } + if err := f(ctx, fireAt, args); err != nil { + logger.Error(err, "TaintEvictionController: timed worker failed") + } + }) + } + if delay <= 0 { + wrapper() + return nil + } + worker.Timer = clock.AfterFunc(delay, wrapper) + return &worker } // Cancel cancels the execution of function by the `TimedWorker` func (w *TimedWorker) Cancel() { if w != nil { + // Mark the worker as cancelled. + // This ensures the worker is either already running or unstarted on return from Cancel. + w.cancelled.Store(true) w.Timer.Stop() } } @@ -93,6 +113,7 @@ type TimedWorkerQueue struct { // map of workers keyed by string returned by 'KeyFromWorkArgs' from the given worker. // Entries may be nil if the work didn't need a timer and is already running. workers map[string]*TimedWorker + workerWG sync.WaitGroup workFunc func(ctx context.Context, fireAt time.Time, args *WorkArgs) error clock clock.WithDelayedExecution } @@ -134,7 +155,7 @@ func (q *TimedWorkerQueue) AddWork(ctx context.Context, args *WorkArgs, createdA return } logger.V(4).Info("Adding TimedWorkerQueue item and to be fired at firedTime", "item", key, "createTime", createdAt, "firedTime", fireAt) - worker := createWorker(ctx, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock) + worker := createWorker(ctx, &q.workerWG, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock) q.workers[key] = worker } @@ -159,7 +180,7 @@ func (q *TimedWorkerQueue) UpdateWork(ctx context.Context, args *WorkArgs, creat worker.Cancel() } logger.V(4).Info("Adding TimedWorkerQueue item and to be fired at firedTime", "item", key, "createTime", createdAt, "firedTime", fireAt) - worker := createWorker(ctx, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock) + worker := createWorker(ctx, &q.workerWG, args, createdAt, fireAt, q.getWrappedWorkerFunc(key), q.clock) q.workers[key] = worker } @@ -189,3 +210,15 @@ func (q *TimedWorkerQueue) GetWorkerUnsafe(key string) *TimedWorker { defer q.Unlock() return q.workers[key] } + +// CancelAndWait cancels all workers and waits for all running threads to terminate before returning. +func (q *TimedWorkerQueue) CancelAndWait() { + // Wait must be called after Unlock, otherwise this hangs. + defer q.workerWG.Wait() + q.Lock() + defer q.Unlock() + for _, worker := range q.workers { + worker.Cancel() + } + q.workers = make(map[string]*TimedWorker) +} diff --git a/pkg/controller/tainteviction/timed_workers_test.go b/pkg/controller/tainteviction/timed_workers_test.go index d39acaed443..89e2e6e4ffd 100644 --- a/pkg/controller/tainteviction/timed_workers_test.go +++ b/pkg/controller/tainteviction/timed_workers_test.go @@ -116,7 +116,7 @@ func TestCancel(t *testing.T) { } } -func TestCancelAndReadd(t *testing.T) { +func TestCancelAndRead(t *testing.T) { logger, ctx := ktesting.NewTestContext(t) testVal := int32(0) wg := sync.WaitGroup{}