/* Copyright 2024 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package testdeviceplugin import ( "context" "fmt" "net" "os" "sync" "time" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" kubeletdevicepluginv1beta1 "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" "k8s.io/kubernetes/pkg/cluster/ports" "k8s.io/kubernetes/test/e2e/framework" e2enode "k8s.io/kubernetes/test/e2e/framework/node" ) var ( kubeletHealthCheckURL = fmt.Sprintf("http://127.0.0.1:%d/healthz", ports.KubeletHealthzPort) ) type DevicePlugin struct { server *grpc.Server uniqueName string devices []*kubeletdevicepluginv1beta1.Device devicesSync sync.Mutex devicesUpdateCh chan struct{} calls []string callsSync sync.Mutex errorInjector func(string) error kubeletdevicepluginv1beta1.UnsafeDevicePluginServer } func NewDevicePlugin(errorInjector func(string) error) *DevicePlugin { return &DevicePlugin{ calls: []string{}, devicesUpdateCh: make(chan struct{}), errorInjector: errorInjector, } } func (dp *DevicePlugin) GetDevicePluginOptions(context.Context, *kubeletdevicepluginv1beta1.Empty) (*kubeletdevicepluginv1beta1.DevicePluginOptions, error) { // lock the mutex and add to a list of calls dp.callsSync.Lock() dp.calls = append(dp.calls, "GetDevicePluginOptions") dp.callsSync.Unlock() if dp.errorInjector != nil { return &kubeletdevicepluginv1beta1.DevicePluginOptions{}, dp.errorInjector("GetDevicePluginOptions") } return &kubeletdevicepluginv1beta1.DevicePluginOptions{}, nil } func (dp *DevicePlugin) sendDevices(stream kubeletdevicepluginv1beta1.DevicePlugin_ListAndWatchServer) error { resp := new(kubeletdevicepluginv1beta1.ListAndWatchResponse) dp.devicesSync.Lock() for _, d := range dp.devices { resp.Devices = append(resp.Devices, d) } dp.devicesSync.Unlock() return stream.Send(resp) } func (dp *DevicePlugin) ListAndWatch(empty *kubeletdevicepluginv1beta1.Empty, stream kubeletdevicepluginv1beta1.DevicePlugin_ListAndWatchServer) error { dp.callsSync.Lock() dp.calls = append(dp.calls, "ListAndWatch") dp.callsSync.Unlock() if dp.errorInjector != nil { if err := dp.errorInjector("ListAndWatch"); err != nil { return err } } if err := dp.sendDevices(stream); err != nil { return err } // when the devices are updated, send the new devices to the kubelet // also start a timer and every second call into the dp.errorInjector // to simulate a device plugin failure ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { case <-dp.devicesUpdateCh: if err := dp.sendDevices(stream); err != nil { return err } case <-ticker.C: if dp.errorInjector != nil { if err := dp.errorInjector("ListAndWatch"); err != nil { return err } } } } } func (dp *DevicePlugin) Allocate(ctx context.Context, request *kubeletdevicepluginv1beta1.AllocateRequest) (*kubeletdevicepluginv1beta1.AllocateResponse, error) { result := new(kubeletdevicepluginv1beta1.AllocateResponse) dp.callsSync.Lock() dp.calls = append(dp.calls, "Allocate") dp.callsSync.Unlock() for _, r := range request.ContainerRequests { response := &kubeletdevicepluginv1beta1.ContainerAllocateResponse{} for _, id := range r.DevicesIds { fpath, err := os.CreateTemp("/tmp", fmt.Sprintf("%s-%s", dp.uniqueName, id)) gomega.Expect(err).To(gomega.Succeed()) response.Mounts = append(response.Mounts, &kubeletdevicepluginv1beta1.Mount{ ContainerPath: fpath.Name(), HostPath: fpath.Name(), }) } result.ContainerResponses = append(result.ContainerResponses, response) } return result, nil } func (dp *DevicePlugin) PreStartContainer(ctx context.Context, request *kubeletdevicepluginv1beta1.PreStartContainerRequest) (*kubeletdevicepluginv1beta1.PreStartContainerResponse, error) { return nil, nil } func (dp *DevicePlugin) GetPreferredAllocation(ctx context.Context, request *kubeletdevicepluginv1beta1.PreferredAllocationRequest) (*kubeletdevicepluginv1beta1.PreferredAllocationResponse, error) { return nil, nil } func (dp *DevicePlugin) RegisterDevicePlugin(ctx context.Context, uniqueName, resourceName string, devices []*kubeletdevicepluginv1beta1.Device) error { ginkgo.GinkgoHelper() dp.devicesSync.Lock() dp.devices = devices dp.devicesSync.Unlock() devicePluginEndpoint := fmt.Sprintf("%s-%s.sock", "test-device-plugin", uniqueName) dp.uniqueName = uniqueName ginkgo.By("Ensuring kubelet is healthy") gomega.Eventually(ctx, func() bool { ok := e2enode.HealthCheck(kubeletHealthCheckURL) framework.Logf("kubelet health check at %q value=%v", kubeletHealthCheckURL, ok) return ok }, framework.PodStartTimeout, framework.Poll).Should(gomega.BeTrueBecause("expected kubelet health check to be successful")) // Implement the logic to register the device plugin with the kubelet ginkgo.By("Create a listener on a specific port") lis, err := net.Listen("unix", kubeletdevicepluginv1beta1.DevicePluginPath+devicePluginEndpoint) if err != nil { return fmt.Errorf("failed to listen on the unix socket, error: %w", err) } ginkgo.By("Wait enough for unix socket to be open") time.Sleep(time.Second) ginkgo.By("Create a new gRPC server") dp.server = grpc.NewServer() ginkgo.By("Register the device plugin with the server") kubeletdevicepluginv1beta1.RegisterDevicePluginServer(dp.server, dp) ginkgo.By("Start the gRPC server") go func() { err := dp.server.Serve(lis) gomega.Expect(err).To(gomega.Succeed()) }() ginkgo.By("Create a connection to the kubelet") options := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithIdleTimeout(5 * time.Second), } ginkgo.By("gRPC server listening at socketPath") conn, err := grpc.NewClient("unix://"+kubeletdevicepluginv1beta1.KubeletSocket, options...) if err != nil { return err } defer func() { err := conn.Close() gomega.Expect(err).To(gomega.Succeed()) }() ginkgo.By("Wait enough for gRPC server unix scoket to be open") time.Sleep(time.Second) ginkgo.By("Create a client for the kubelet") client := kubeletdevicepluginv1beta1.NewRegistrationClient(conn) reqt := &kubeletdevicepluginv1beta1.RegisterRequest{ Version: kubeletdevicepluginv1beta1.Version, Endpoint: devicePluginEndpoint, ResourceName: resourceName, } ginkgo.By("Register the device plugin with the kubelet") _, err = client.Register(ctx, reqt) if err != nil { return err } return nil } func (dp *DevicePlugin) Stop() { ginkgo.By("Waiting one second before stoping gRPC server") time.Sleep(time.Second) if dp.server != nil { ginkgo.By("Stoping gRPC server stopped") dp.server.Stop() ginkgo.By("Fake gRPC server stopped") dp.server = nil } } func (dp *DevicePlugin) WasCalled(method string) bool { // lock mutex and then search if the method was called dp.callsSync.Lock() defer dp.callsSync.Unlock() for _, call := range dp.calls { if call == method { return true } } return false } func (dp *DevicePlugin) Calls() []string { // lock the mutex and return the calls dp.callsSync.Lock() defer dp.callsSync.Unlock() // return a copy of the calls calls := make([]string, len(dp.calls)) copy(calls, dp.calls) return calls } func (dp *DevicePlugin) UpdateDevices(devices []*kubeletdevicepluginv1beta1.Device) { // lock the mutex and update the devices dp.devicesSync.Lock() defer dp.devicesSync.Unlock() dp.devices = devices dp.devicesUpdateCh <- struct{}{} }