From b255410b4f6197418b18293b55e663ac820e8518 Mon Sep 17 00:00:00 2001 From: k-diger Date: Tue, 9 Sep 2025 21:16:44 +0900 Subject: [PATCH] Remove duplicate connection management in DRA plugin Fixes --- pkg/kubelet/cm/dra/plugin/dra_plugin.go | 74 +----- pkg/kubelet/cm/dra/plugin/dra_plugin_test.go | 61 ++++- test/e2e_node/dra_test.go | 245 ++++++++++++++----- 3 files changed, 237 insertions(+), 143 deletions(-) diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin.go b/pkg/kubelet/cm/dra/plugin/dra_plugin.go index 1529bdcc3af..1765bbaaaab 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin.go @@ -18,15 +18,11 @@ package plugin import ( "context" - "errors" "fmt" - "net" "sync" "time" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "k8s.io/klog/v2" @@ -65,70 +61,10 @@ type DRAPlugin struct { mutex sync.Mutex backgroundCtx context.Context - healthClient drahealthv1alpha1.DRAResourceHealthClient healthStreamCtx context.Context healthStreamCancel context.CancelFunc } -func (p *DRAPlugin) getOrCreateGRPCConn() (*grpc.ClientConn, error) { - p.mutex.Lock() - defer p.mutex.Unlock() - - // If connection exists and is ready, return it. - if p.conn != nil && p.conn.GetState() != connectivity.Shutdown { - // Initialize health client if connection exists but client is nil - // This allows lazy init if connection was established before health was added. - if p.healthClient == nil { - p.healthClient = drahealthv1alpha1.NewDRAResourceHealthClient(p.conn) - klog.FromContext(p.backgroundCtx).V(4).Info("Initialized DRAResourceHealthClient lazily") - } - return p.conn, nil - } - - // If the connection is dead, clean it up before creating a new one. - if p.conn != nil { - if err := p.conn.Close(); err != nil { - return nil, fmt.Errorf("failed to close stale gRPC connection to %s: %w", p.endpoint, err) - } - p.conn = nil - p.healthClient = nil - } - - ctx := p.backgroundCtx - logger := klog.FromContext(ctx) - - network := "unix" - logger.V(4).Info("Creating new gRPC connection", "protocol", network, "endpoint", p.endpoint) - // grpc.Dial is deprecated. grpc.NewClient should be used instead. - // For now this gets ignored because this function is meant to establish - // the connection, with the one second timeout below. Perhaps that - // approach should be reconsidered? - //nolint:staticcheck - conn, err := grpc.Dial( - p.endpoint, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, network, target) - }), - grpc.WithChainUnaryInterceptor(newMetricsInterceptor(p.driverName)), - ) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - if ok := conn.WaitForStateChange(ctx, connectivity.Connecting); !ok { - return nil, errors.New("timed out waiting for gRPC connection to be ready") - } - - p.conn = conn - p.healthClient = drahealthv1alpha1.NewDRAResourceHealthClient(p.conn) - - return p.conn, nil -} - func (p *DRAPlugin) DriverName() string { return p.driverName } @@ -221,17 +157,11 @@ func (p *DRAPlugin) HealthStreamCancel() context.CancelFunc { // NodeWatchResources establishes a stream to receive health updates from the DRA plugin. func (p *DRAPlugin) NodeWatchResources(ctx context.Context) (drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesClient, error) { - // Ensure a connection and the health client exist before proceeding. - // This call is idempotent and will create them if they don't exist. - _, err := p.getOrCreateGRPCConn() - if err != nil { - klog.FromContext(p.backgroundCtx).Error(err, "Failed to get gRPC connection for health client") - return nil, err - } + healthClient := drahealthv1alpha1.NewDRAResourceHealthClient(p.conn) logger := klog.FromContext(ctx).WithValues("pluginName", p.driverName) logger.V(4).Info("Starting WatchResources stream") - stream, err := p.healthClient.NodeWatchResources(ctx, &drahealthv1alpha1.NodeWatchResourcesRequest{}) + stream, err := healthClient.NodeWatchResources(ctx, &drahealthv1alpha1.NodeWatchResourcesRequest{}) if err != nil { logger.Error(err, "NodeWatchResources RPC call failed") return nil, err diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go b/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go index 9a7c856bdb4..a63576896f3 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go @@ -129,9 +129,8 @@ func setupFakeGRPCServer(service, addr string) (tearDown, error) { func TestGRPCConnIsReused(t *testing.T) { tCtx := ktesting.Init(t) - service := drapbv1.DRAPluginService addr := path.Join(t.TempDir(), "dra.sock") - teardown, err := setupFakeGRPCServer(service, addr) + teardown, err := setupFakeGRPCServer("", addr) require.NoError(t, err) defer teardown() @@ -143,7 +142,7 @@ func TestGRPCConnIsReused(t *testing.T) { // ensure the plugin we are using is registered draPlugins := NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) - tCtx.ExpectNoError(draPlugins.add(driverName, addr, service, defaultClientCallTimeout), "add plugin") + tCtx.ExpectNoError(draPlugins.add(driverName, addr, drapbv1.DRAPluginService, defaultClientCallTimeout), "add plugin") plugin, err := draPlugins.GetPlugin(driverName) tCtx.ExpectNoError(err, "get plugin") conn := plugin.conn @@ -184,6 +183,62 @@ func TestGRPCConnIsReused(t *testing.T) { // We should have only one entry otherwise it means another gRPC connection has been created require.Len(t, reusedConns, 1, "expected length to be 1 but got %d", len(reusedConns)) require.Equal(t, 2, reusedConns[conn], "expected counter to be 2 but got %d", reusedConns[conn]) + + tCtx.Run("health_api_reuses_connection", func(tCtx ktesting.TContext) { + ctx, cancel := context.WithTimeout(tCtx, 5*time.Second) + defer cancel() + + originalConn := plugin.conn + + stream, err := plugin.NodeWatchResources(ctx) + require.NoError(tCtx, err, "Health stream should work") + require.NotNil(tCtx, stream) + + require.Equal(tCtx, originalConn, plugin.conn, "Health API should reuse the same connection") + + resp, err := stream.Recv() + require.NoError(tCtx, err, "Should receive health data") + require.NotNil(tCtx, resp) + require.Len(tCtx, resp.Devices, 1) + assert.Equal(tCtx, "pool1", resp.Devices[0].GetDevice().GetPoolName()) + assert.Equal(tCtx, "dev1", resp.Devices[0].GetDevice().GetDeviceName()) + assert.Equal(tCtx, drahealthv1alpha1.HealthStatus_HEALTHY, resp.Devices[0].GetHealth()) + + require.Equal(tCtx, originalConn, plugin.conn, "Connection should remain unchanged after health operations") + + prepareReq := &drapbv1.NodePrepareResourcesRequest{ + Claims: []*drapbv1.Claim{ + { + Namespace: "dummy-namespace", + Uid: "dummy-uid", + Name: "dummy-claim", + }, + }, + } + + prepareResp, err := plugin.NodePrepareResources(ctx, prepareReq) + require.NoError(tCtx, err, "NodePrepareResources should work") + require.NotNil(tCtx, prepareResp) + require.NotNil(tCtx, prepareResp.Claims["claim-uid"]) + + require.Equal(tCtx, originalConn, plugin.conn, "Connection should remain unchanged after NodePrepareResources") + + unprepareReq := &drapbv1.NodeUnprepareResourcesRequest{ + Claims: []*drapbv1.Claim{ + { + Namespace: "dummy-namespace", + Uid: "dummy-uid", + Name: "dummy-claim", + }, + }, + } + + unprepareResp, err := plugin.NodeUnprepareResources(ctx, unprepareReq) + require.NoError(tCtx, err, "NodeUnprepareResources should work") + require.NotNil(tCtx, unprepareResp) + + require.Equal(tCtx, originalConn, plugin.conn, "Connection should remain unchanged after all API calls") + }) } func TestGRPCConnUsableAfterIdle(t *testing.T) { diff --git a/test/e2e_node/dra_test.go b/test/e2e_node/dra_test.go index 183bcf50987..8254d370e48 100644 --- a/test/e2e_node/dra_test.go +++ b/test/e2e_node/dra_test.go @@ -876,28 +876,10 @@ var _ = framework.SIGDescribe("node")(framework.WithLabel("DRA"), feature.Dynami ginkgo.By("Starting the test driver with channel-based control") kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, f.Namespace.Name, getNodeName(ctx, f), driverName) - className := "health-test-class" - claimName := "health-test-claim" - podName := "health-test-pod" - poolNameForTest := "pool-a" - deviceNameForTest := "dev-0" - - pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) - - ginkgo.By("Waiting for the pod to be running") - framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) - - ginkgo.By("Forcing a 'Healthy' status update to establish a baseline") - kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ - PoolName: poolNameForTest, - DeviceName: deviceNameForTest, - Health: "Healthy", - } - - ginkgo.By("Verifying device health is now Healthy in the pod status") - gomega.Eventually(ctx, func(ctx context.Context) (string, error) { - return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) - }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy after explicit update") + pod, claimName, poolNameForTest, deviceNameForTest := setupAndVerifyHealthyPod( + ctx, f, kubeletPlugin, driverName, + "health-test-class", "health-test-claim", "health-test-pod", "pool-a", "dev-0", + ) ginkgo.By("Setting device health to Unhealthy via control channel") kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ @@ -956,34 +938,16 @@ var _ = framework.SIGDescribe("node")(framework.WithLabel("DRA"), feature.Dynami draService := newDRAService(ctx, f.ClientSet, f.Namespace.Name, nodeName, driverName, "", kubeletplugin.PluginListener(getListener)) for _, suffix := range []string{"-1", "-2"} { - className := "health-test-class" + suffix - claimName := "health-test-claim" + suffix - podName := "health-test-pod" + suffix - poolNameForTest := "pool-a" + suffix - deviceNameForTest := "dev-0" + suffix - draService.ResetGRPCCalls() - ginkgo.By("create test objects " + suffix) - pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) - - ginkgo.By("wait for NodePrepareResources call to succeed " + suffix) - gomega.Eventually(draService.GetGRPCCalls).WithTimeout(retryTestTimeout).Should(testdrivergomega.NodePrepareResourcesSucceeded) - - ginkgo.By("Waiting for the pod to be running") - framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) - - ginkgo.By("Forcing a 'Healthy' status update to establish a baseline") - draService.HealthControlChan <- testdriver.DeviceHealthUpdate{ - PoolName: poolNameForTest, - DeviceName: deviceNameForTest, - Health: "Healthy", - } - - ginkgo.By("Verifying device health is now Healthy in the pod status") - gomega.Eventually(ctx, func(ctx context.Context) (string, error) { - return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) - }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy after explicit update") + setupAndVerifyHealthyPod( + ctx, f, draService, driverName, + "health-test-class"+suffix, + "health-test-claim"+suffix, + "health-test-pod"+suffix, + "pool-a"+suffix, + "dev-0"+suffix, + ) ginkgo.By("check that listener.Accept was called only once " + suffix) gomega.Expect(listener.acceptCount()).To(gomega.Equal(1)) @@ -996,26 +960,10 @@ var _ = framework.SIGDescribe("node")(framework.WithLabel("DRA"), feature.Dynami ginkgo.By("Starting the test driver") kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, f.Namespace.Name, getNodeName(ctx, f), driverName) - className := "unknown-test-class" - claimName := "unknown-test-claim" - podName := "unknown-test-pod" - poolNameForTest := "pool-b" - deviceNameForTest := "dev-1" - - pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) - - ginkgo.By("Waiting for the pod to be running") - framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) - - ginkgo.By("Establishing a baseline 'Healthy' status") - kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ - PoolName: poolNameForTest, - DeviceName: deviceNameForTest, - Health: "Healthy", - } - gomega.Eventually(ctx, func(ctx context.Context) (string, error) { - return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) - }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy initially") + pod, claimName, poolNameForTest, deviceNameForTest := setupAndVerifyHealthyPod( + ctx, f, kubeletPlugin, driverName, + "unknown-test-class", "unknown-test-claim", "unknown-test-pod", "pool-b", "dev-1", + ) ginkgo.By("Stopping the DRA plugin to simulate a crash") kubeletPlugin.Stop() @@ -1042,6 +990,130 @@ var _ = framework.SIGDescribe("node")(framework.WithLabel("DRA"), feature.Dynami }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Healthy"), "Device health should recover to Healthy after plugin restarts") }) + // Reconnection verification after connection drop + ginkgo.It("should automatically reconnect both DRA and Health API after connection drop", func(ctx context.Context) { + ginkgo.By("Starting the test driver") + kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, f.Namespace.Name, getNodeName(ctx, f), driverName) + + ginkgo.By("Creating initial pod to establish connection") + pod, claimName, poolNameForTest, deviceNameForTest := setupAndVerifyHealthyPod( + ctx, f, kubeletPlugin, driverName, + "reconnect-test-class", "reconnect-test-claim", "reconnect-test-pod", "pool-reconnect", "dev-reconnect-0", + ) + + ginkgo.By("Simulating connection drop by stopping the plugin") + kubeletPlugin.Stop() + + ginkgo.By("Verifying health transitions to Unknown after connection drop") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(2*time.Minute).WithPolling(5*time.Second).Should(gomega.Equal("Unknown"), + "Health should become Unknown when plugin connection is lost") + + ginkgo.By("Restarting the plugin to simulate reconnection") + kubeletPlugin = newKubeletPlugin(ctx, f.ClientSet, f.Namespace.Name, getNodeName(ctx, f), driverName) + + ginkgo.By("Waiting for plugin registration after restart") + gomega.Eventually(kubeletPlugin.GetGRPCCalls).WithTimeout(pluginRegistrationTimeout).Should(testdrivergomega.BeRegistered, + "Plugin should re-register after restart") + + ginkgo.By("Creating a new pod to trigger NodePrepareResources after reconnection") + pod2, claimName2, poolNameForTest2, deviceNameForTest2 := setupAndVerifyHealthyPod( + ctx, f, kubeletPlugin, driverName, + "reconnect-test-class-2", "reconnect-test-claim-2", "reconnect-test-pod-2", "pool-reconnect-2", "dev-reconnect-1", + ) + + ginkgo.By("Verifying health updates continue to work") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest2, + DeviceName: deviceNameForTest2, + Health: "Unhealthy", + } + + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod2.Namespace, pod2.Name, driverName, claimName2, poolNameForTest2, deviceNameForTest2) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Unhealthy"), + "Health API should continue working after reconnection") + }) + + // Concurrent operations verification + ginkgo.It("should handle concurrent DRA operations and health monitoring without connection issues", func(ctx context.Context) { + ginkgo.By("Starting the test driver") + kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, f.Namespace.Name, getNodeName(ctx, f), driverName) + + numPods := 3 + pods := make([]*v1.Pod, numPods) + claimNames := make([]string, numPods) + poolNames := make([]string, numPods) + deviceNames := make([]string, numPods) + + ginkgo.By(fmt.Sprintf("Creating %d pods concurrently to stress test connection management", numPods)) + for i := 0; i < numPods; i++ { + className := fmt.Sprintf("concurrent-class-%d", i) + claimNames[i] = fmt.Sprintf("concurrent-claim-%d", i) + podName := fmt.Sprintf("concurrent-pod-%d", i) + poolNames[i] = fmt.Sprintf("pool-concurrent-%d", i) + deviceNames[i] = fmt.Sprintf("dev-concurrent-%d", i) + + pods[i] = createHealthTestPodAndClaim(ctx, f, driverName, podName, claimNames[i], className, poolNames[i], deviceNames[i]) + } + + ginkgo.By("Waiting for all pods to be running") + for i, pod := range pods { + framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod), + fmt.Sprintf("Pod %d should be running", i)) + } + + ginkgo.By("Verifying NodePrepareResources was called for all pods") + gomega.Eventually(func() int { + return kubeletPlugin.CountCalls("/NodePrepareResources") + }).WithTimeout(retryTestTimeout).Should(gomega.BeNumerically(">=", numPods), + "NodePrepareResources should be called at least once per pod") + + ginkgo.By("Sending health updates for all devices concurrently") + for i := 0; i < numPods; i++ { + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNames[i], + DeviceName: deviceNames[i], + Health: "Healthy", + } + } + + ginkgo.By("Verifying all health updates are correctly reflected") + for i := 0; i < numPods; i++ { + pod := pods[i] + poolName := poolNames[i] + deviceName := deviceNames[i] + claimName := claimNames[i] + + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolName, deviceName) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Healthy"), + fmt.Sprintf("Health for device %s should be Healthy", deviceName)) + } + + ginkgo.By("Changing health status for all devices to verify continued operation") + for i := 0; i < numPods; i++ { + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNames[i], + DeviceName: deviceNames[i], + Health: "Unhealthy", + } + } + + ginkgo.By("Verifying all health changes are correctly reflected") + for i := 0; i < numPods; i++ { + pod := pods[i] + poolName := poolNames[i] + deviceName := deviceNames[i] + claimName := claimNames[i] + + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolName, deviceName) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Unhealthy"), + fmt.Sprintf("Health for device %s should be Unhealthy", deviceName)) + } + }) }) // This matches the "Resource Health" context above, except that it contains tests which need to run @@ -1613,6 +1685,43 @@ func createHealthTestPodAndClaim(ctx context.Context, f *framework.Framework, dr return createdPod } +// setupAndVerifyHealthyPod creates a test pod with health monitoring and verifies it reaches a healthy state. +// It encapsulates the repeated pattern of pod creation, waiting for NodePrepareResources, +// pod startup, and health verification that appears throughout the health monitoring tests. +func setupAndVerifyHealthyPod( + ctx context.Context, + f *framework.Framework, + plugin *testdriver.ExamplePlugin, + driverName string, + className string, + claimName string, + podName string, + poolName string, + deviceName string, +) (*v1.Pod, string, string, string) { + pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolName, deviceName) + + ginkgo.By("wait for NodePrepareResources call to succeed") + gomega.Eventually(plugin.GetGRPCCalls).WithTimeout(retryTestTimeout).Should(testdrivergomega.NodePrepareResourcesSucceeded) + + ginkgo.By("Waiting for the pod to be running") + framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) + + ginkgo.By("Forcing a 'Healthy' status update to establish a baseline") + plugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolName, + DeviceName: deviceName, + Health: "Healthy", + } + + ginkgo.By("Verifying device health is now Healthy in the pod status") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolName, deviceName) + }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy after explicit update") + + return pod, claimName, poolName, deviceName +} + // errorOnCloseListener is a mock net.Listener that blocks on Accept() // until Close() is called, at which point Accept() returns a predefined error. //