diff --git a/pkg/agent/run.go b/pkg/agent/run.go index 3d2436198dc..d9def1e1f9f 100644 --- a/pkg/agent/run.go +++ b/pkg/agent/run.go @@ -26,6 +26,8 @@ import ( "github.com/k3s-io/k3s/pkg/daemons/agent" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/daemons/executor" + "github.com/k3s-io/k3s/pkg/daemons/health" + "github.com/k3s-io/k3s/pkg/daemons/watchdog" "github.com/k3s-io/k3s/pkg/metrics" "github.com/k3s-io/k3s/pkg/nodeconfig" "github.com/k3s-io/k3s/pkg/profile" @@ -41,6 +43,7 @@ import ( "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/watch" + "k8s.io/apiserver/pkg/server/healthz" "k8s.io/client-go/kubernetes" toolscache "k8s.io/client-go/tools/cache" toolswatch "k8s.io/client-go/tools/watch" @@ -141,6 +144,14 @@ func run(ctx context.Context, cfg cmds.Agent, proxy proxy.Proxy) error { notifySocket := os.Getenv("NOTIFY_SOCKET") os.Unsetenv("NOTIFY_SOCKET") + // Capture the watchdog interval before stripping WATCHDOG_USEC so the + // kubelet's own NewHealthChecker short-circuits and doesn't spawn a + // goroutine that logs "Failed to notify watchdog" every tick. + watchdogInterval, werr := systemd.SdWatchdogEnabled(true) + if werr != nil { + logrus.Warnf("systemd watchdog: failed to read WATCHDOG_USEC, watchdog disabled: %v", werr) + } + go func() { if err := startCRI(ctx, nodeConfig); err != nil { signals.RequestShutdown(errors.WithMessage(err, "failed to start container runtime")) @@ -164,12 +175,27 @@ func run(ctx context.Context, cfg cmds.Agent, proxy proxy.Proxy) error { logrus.Info(version.Program + " agent is up and running") os.Setenv("NOTIFY_SOCKET", notifySocket) systemd.SdNotify(true, "READY=1\n") + go watchdog.Run(ctx, notifySocket, watchdogInterval, agentHealthCheckers(nodeConfig)) } }() return nil } +// agentHealthCheckers returns the set of liveness checks that must all pass +// for the k3s agent process to be considered healthy by the systemd +// watchdog. Only used on agent-only nodes; on server nodes the server's +// checker set covers kubelet (and kube-proxy when gated by DisableKubeProxy +// on Control). kube-proxy is not included here because its disabled state is +// not stored on nodeConfig — adding it unconditionally would silence the +// watchdog whenever kube-proxy is disabled. +func agentHealthCheckers(nodeConfig *daemonconfig.Node) []healthz.HealthChecker { + return []healthz.HealthChecker{ + health.NewHTTPGetHealthz("kubelet", "http://127.0.0.1:10248/healthz"), + health.NewGRPCHealthz("cri", nodeConfig.AgentConfig.RuntimeSocket), + } +} + // startCRI starts the configured CRI, or waits for an external CRI to be ready. func startCRI(ctx context.Context, nodeConfig *daemonconfig.Node) error { if nodeConfig.Docker { diff --git a/pkg/cli/cmds/log_linux.go b/pkg/cli/cmds/log_linux.go index 513e94db05d..389673f5d40 100644 --- a/pkg/cli/cmds/log_linux.go +++ b/pkg/cli/cmds/log_linux.go @@ -57,7 +57,11 @@ func forkIfLoggingOrReaping() error { } args := append([]string{version.Program}, os.Args[1:]...) - env := append(os.Environ(), "_K3S_LOG_REEXEC_=true", "NOTIFY_SOCKET=") + // NOTIFY_SOCKET is intentionally passed through to the child so that + // pkg/daemons/watchdog can drive systemd's WATCHDOG=1 pings. The child + // strips NOTIFY_SOCKET from its env early (server.go / agent run.go) + // before any embedded component can read it. + env := append(os.Environ(), "_K3S_LOG_REEXEC_=true") ctx := signals.SetupSignalContext() cmd := exec.CommandContext(ctx, "/proc/self/exe") cmd.Args = args diff --git a/pkg/cli/server/server.go b/pkg/cli/server/server.go index b6e3257b328..c09ceb8bca0 100644 --- a/pkg/cli/server/server.go +++ b/pkg/cli/server/server.go @@ -6,6 +6,7 @@ import ( "net" "os" "path/filepath" + "strconv" "strings" "sync" "time" @@ -18,6 +19,8 @@ import ( "github.com/k3s-io/k3s/pkg/clientaccess" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/daemons/executor" + "github.com/k3s-io/k3s/pkg/daemons/health" + "github.com/k3s-io/k3s/pkg/daemons/watchdog" "github.com/k3s-io/k3s/pkg/datadir" "github.com/k3s-io/k3s/pkg/etcd" k3smetrics "github.com/k3s-io/k3s/pkg/metrics" @@ -39,6 +42,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/apiserver/pkg/server/healthz" kubeapiserverflag "k8s.io/component-base/cli/flag" "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/controlplane/apiserver/options" @@ -476,6 +480,14 @@ func run(app *cli.Context, cfg *cmds.Server, leaderControllers server.CustomCont notifySocket := os.Getenv("NOTIFY_SOCKET") os.Unsetenv("NOTIFY_SOCKET") + // Capture the watchdog interval before stripping WATCHDOG_USEC so the + // kubelet's own NewHealthChecker short-circuits and doesn't spawn a + // goroutine that logs "Failed to notify watchdog" every tick. + watchdogInterval, err := systemd.SdWatchdogEnabled(true) + if err != nil { + logrus.Warnf("systemd watchdog: failed to read WATCHDOG_USEC, watchdog disabled: %v", err) + } + // try setting advertise-ip from agent VPN if vpnInfo, _ := vpn.GetInfoFromExecutor(); vpnInfo != nil { // If we are in ipv6-only mode, we should pass the ipv6 address. Otherwise, ipv4 @@ -615,11 +627,53 @@ func run(app *cli.Context, cfg *cmds.Server, leaderControllers server.CustomCont logrus.Info(version.Program + " is up and running") os.Setenv("NOTIFY_SOCKET", notifySocket) systemd.SdNotify(true, "READY=1\n") + go watchdog.Run(ctx, notifySocket, watchdogInterval, serverHealthCheckers(&serverConfig.ControlConfig, cfg.DisableAgent)) }() return server.StartServer(ctx, wg, &serverConfig, cfg) } +// serverHealthCheckers returns the liveness checks that must all pass for the +// k3s server process to be considered healthy by the systemd watchdog. +// Components disabled via --disable-* are skipped. When the embedded agent is +// running on this node, kubelet and kube-proxy are included too, since the +// same process owns their liveness. Endpoints mirror the RKE2 static-pod +// readiness/liveness probes: +// +// https://github.com/rancher/rke2/blob/v1.36.0+rke2r1/pkg/podtemplate/spec.go +func serverHealthCheckers(cc *config.Control, agentDisabled bool) []healthz.HealthChecker { + var checkers []healthz.HealthChecker + host := cc.Loopback(false) + + if !cc.DisableAPIServer && cc.HTTPSPort > 0 { + // k3s starts kube-apiserver with --anonymous-auth=false, so /livez + // requires a client cert; the admin client cert is the simplest + // in-process identity that satisfies authn + authorization. + url := fmt.Sprintf("https://%s/livez", net.JoinHostPort(host, strconv.Itoa(cc.HTTPSPort))) + checkers = append(checkers, health.NewHTTPGetWithClientCertHealthz("kube-apiserver", url, + cc.Runtime.ClientAdminCert, cc.Runtime.ClientAdminKey)) + } + if !cc.DisableETCD { + checkers = append(checkers, health.NewHTTPGetHealthz("etcd", fmt.Sprintf("http://%s/health?serializable=true", net.JoinHostPort(host, "2381")))) + } + if !cc.DisableControllerManager { + checkers = append(checkers, health.NewHTTPGetHealthz("kube-controller-manager", fmt.Sprintf("https://%s/healthz", net.JoinHostPort(host, "10257")))) + } + if !cc.DisableScheduler { + checkers = append(checkers, health.NewHTTPGetHealthz("kube-scheduler", fmt.Sprintf("https://%s/healthz", net.JoinHostPort(host, "10259")))) + } + if cc.SupervisorPort > 0 { + checkers = append(checkers, health.NewHTTPGetHealthz("supervisor", fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(cc.SupervisorPort))))) + } + if !agentDisabled { + checkers = append(checkers, health.NewHTTPGetHealthz("kubelet", fmt.Sprintf("http://%s/healthz", net.JoinHostPort(host, "10248")))) + if !cc.DisableKubeProxy { + checkers = append(checkers, health.NewHTTPGetHealthz("kube-proxy", fmt.Sprintf("http://%s/healthz", net.JoinHostPort(host, "10256")))) + } + } + return checkers +} + // validateNetworkConfig ensures that the network configuration values make sense. func validateNetworkConfiguration(serverConfig server.Config) error { switch serverConfig.ControlConfig.EgressSelectorMode { diff --git a/pkg/daemons/health/health.go b/pkg/daemons/health/health.go new file mode 100644 index 00000000000..824fd676e29 --- /dev/null +++ b/pkg/daemons/health/health.go @@ -0,0 +1,115 @@ +package health + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "k8s.io/apiserver/pkg/server/healthz" +) + +const dialTimeout = 5 * time.Second + +func NewHTTPGetHealthz(name, url string) healthz.HealthChecker { + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DisableKeepAlives: true, + }, + } + return healthz.NamedCheck(name, func(_ *http.Request) error { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("get %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("get %s: status %d", url, resp.StatusCode) + } + return nil + }) +} + +func NewHTTPGetWithClientCertHealthz(name, url, certFile, keyFile string) healthz.HealthChecker { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return healthz.NamedCheck(name, func(_ *http.Request) error { + return fmt.Errorf("load client cert %s/%s: %w", certFile, keyFile, err) + }) + } + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + }, + DisableKeepAlives: true, + }, + } + return healthz.NamedCheck(name, func(_ *http.Request) error { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("get %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("get %s: status %d", url, resp.StatusCode) + } + return nil + }) +} + +func NewTCPConnectHealthz(name, addr string) healthz.HealthChecker { + return healthz.NamedCheck(name, func(_ *http.Request) error { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("dial tcp %s: %w", addr, err) + } + return conn.Close() + }) +} + +func NewGRPCHealthz(name, target string) healthz.HealthChecker { + target = strings.TrimPrefix(target, "unix://") + dialTarget := "unix:" + target + return healthz.NamedCheck(name, func(_ *http.Request) error { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + conn, err := grpc.NewClient(dialTarget, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return fmt.Errorf("dial grpc %s: %w", dialTarget, err) + } + defer conn.Close() + client := healthpb.NewHealthClient(conn) + resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) + if err != nil { + return fmt.Errorf("grpc health check %s: %w", dialTarget, err) + } + if resp.Status != healthpb.HealthCheckResponse_SERVING { + return fmt.Errorf("grpc health check %s: status %s", dialTarget, resp.Status) + } + return nil + }) +} diff --git a/pkg/daemons/health/health_test.go b/pkg/daemons/health/health_test.go new file mode 100644 index 00000000000..fcadf8bf1ff --- /dev/null +++ b/pkg/daemons/health/health_test.go @@ -0,0 +1,206 @@ +package health + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "google.golang.org/grpc" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +func Test_UnitTCPChecker(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { ln.Close() }) + + c := NewTCPConnectHealthz("ok", ln.Addr().String()) + if c.Name() != "ok" { + t.Errorf("Name() = %q, want %q", c.Name(), "ok") + } + if err := c.Check(nil); err != nil { + t.Errorf("expected open port to pass, got %v", err) + } + ln.Close() + if err := NewTCPConnectHealthz("closed", ln.Addr().String()).Check(nil); err == nil { + t.Errorf("expected closed port to fail") + } +} + +func Test_UnitHTTPGetChecker(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/ok": + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusInternalServerError) + } + })) + t.Cleanup(srv.Close) + + if err := NewHTTPGetHealthz("ok", srv.URL+"/ok").Check(nil); err != nil { + t.Errorf("expected 200 to pass, got %v", err) + } + if err := NewHTTPGetHealthz("fail", srv.URL+"/fail").Check(nil); err == nil { + t.Errorf("expected 500 to fail") + } + if err := NewHTTPGetHealthz("dial", "http://127.0.0.1:1/never").Check(nil); err == nil { + t.Errorf("expected unreachable URL to fail") + } +} + +func Test_UnitHTTPGetCheckerSkipsTLSVerify(t *testing.T) { + // httptest.NewTLSServer uses a self-signed cert; the checker must accept it. + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + if err := NewHTTPGetHealthz("tls", srv.URL+"/livez").Check(nil); err != nil { + t.Errorf("expected TLS-skip-verify probe to pass against self-signed cert, got %v", err) + } +} + +func Test_UnitHTTPGetWithClientCertChecker(t *testing.T) { + // TLS server that requires (and inspects) a client cert. + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.TLS.PeerCertificates) == 0 { + http.Error(w, "no client cert", http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) + })) + srv.TLS = &tls.Config{ClientAuth: tls.RequireAnyClientCert} + srv.StartTLS() + t.Cleanup(srv.Close) + + // Generate a throwaway client cert + key on disk. + certPath, keyPath := writeSelfSignedCert(t) + + if err := NewHTTPGetWithClientCertHealthz("apiserver", srv.URL+"/livez", certPath, keyPath).Check(nil); err != nil { + t.Errorf("expected probe with client cert to pass, got %v", err) + } + + // Without a cert the same endpoint should 401. + if err := NewHTTPGetHealthz("apiserver-anon", srv.URL+"/livez").Check(nil); err == nil { + t.Errorf("expected anonymous probe to fail without client cert") + } +} + +func Test_UnitHTTPGetWithClientCertMissingFiles(t *testing.T) { + // Cert path doesn't exist — Check should return an error every call, + // not panic. + c := NewHTTPGetWithClientCertHealthz("apiserver", "https://127.0.0.1:1/livez", "/nonexistent.crt", "/nonexistent.key") + if err := c.Check(nil); err == nil { + t.Errorf("expected missing cert files to surface as a Check error") + } +} + +// writeSelfSignedCert generates an in-memory self-signed cert and writes the +// PEM-encoded cert and key to a tempdir. Returns their paths. +func writeSelfSignedCert(t *testing.T) (certPath, keyPath string) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test-client"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + dir := t.TempDir() + certPath = filepath.Join(dir, "client.crt") + keyPath = filepath.Join(dir, "client.key") + if err := os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), 0600); err != nil { + t.Fatalf("write cert: %v", err) + } + keyDER, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatalf("marshal key: %v", err) + } + if err := os.WriteFile(keyPath, pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER}), 0600); err != nil { + t.Fatalf("write key: %v", err) + } + return certPath, keyPath +} + +// healthServer is a minimal implementation of grpc.health.v1.Health that +// returns a configurable status — used to test the gRPC health checker +// against both SERVING and NOT_SERVING responses. +type healthServer struct { + healthpb.UnimplementedHealthServer + status healthpb.HealthCheckResponse_ServingStatus +} + +func (h *healthServer) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + return &healthpb.HealthCheckResponse{Status: h.status}, nil +} + +func startHealthServer(t *testing.T, status healthpb.HealthCheckResponse_ServingStatus) string { + t.Helper() + dir := t.TempDir() + socket := filepath.Join(dir, "grpc.sock") + + ln, err := net.Listen("unix", socket) + if err != nil { + t.Fatalf("listen unix: %v", err) + } + srv := grpc.NewServer() + healthpb.RegisterHealthServer(srv, &healthServer{status: status}) + go srv.Serve(ln) + t.Cleanup(func() { + srv.Stop() + ln.Close() + }) + return socket +} + +func Test_UnitGRPCCheckerServing(t *testing.T) { + socket := startHealthServer(t, healthpb.HealthCheckResponse_SERVING) + if err := NewGRPCHealthz("cri", socket).Check(nil); err != nil { + t.Errorf("expected SERVING status to pass, got %v", err) + } +} + +func Test_UnitGRPCCheckerNotServing(t *testing.T) { + socket := startHealthServer(t, healthpb.HealthCheckResponse_NOT_SERVING) + if err := NewGRPCHealthz("cri", socket).Check(nil); err == nil { + t.Errorf("expected NOT_SERVING status to fail") + } +} + +func Test_UnitGRPCCheckerStripsUnixScheme(t *testing.T) { + socket := startHealthServer(t, healthpb.HealthCheckResponse_SERVING) + if err := NewGRPCHealthz("cri", "unix://"+socket).Check(nil); err != nil { + t.Errorf("expected unix:// scheme to be stripped, got %v", err) + } +} + +func Test_UnitGRPCCheckerMissingSocket(t *testing.T) { + dir := t.TempDir() + socket := filepath.Join(dir, "absent.sock") + if err := NewGRPCHealthz("cri", socket).Check(nil); err == nil { + t.Errorf("expected missing socket to fail") + } +} diff --git a/pkg/daemons/watchdog/watchdog.go b/pkg/daemons/watchdog/watchdog.go new file mode 100644 index 00000000000..bb395b42f59 --- /dev/null +++ b/pkg/daemons/watchdog/watchdog.go @@ -0,0 +1,93 @@ +// Package watchdog implements the k3s side of the systemd notify / watchdog +// protocol. +// +// k3s strips NOTIFY_SOCKET (and WATCHDOG_USEC) from the process environment +// early in startup so embedded components — kubelet, etcd, kine, etc. — +// cannot ping systemd on behalf of the whole process. That is intentional: +// the kubelet by itself has no visibility into etcd, the apiserver, or the +// CRI runtime, so letting it ping the watchdog would mask whole-process +// failures. +// +// READY=1 is still sent the usual way via systemd.SdNotify in the server / +// agent startup code, which temporarily restores NOTIFY_SOCKET and then +// unsets it again. This package owns the periodic WATCHDOG=1 pings: callers +// pass in the cached NOTIFY_SOCKET path and WATCHDOG_USEC interval that they +// captured before stripping, plus the set of healthz.HealthCheckers covering +// every component that must be alive for the process to be considered +// healthy. WATCHDOG=1 is only sent while every check passes; otherwise the +// loop stays quiet and systemd will restart the unit after WatchdogSec. +package watchdog + +import ( + "context" + "errors" + "net" + "time" + + systemd "github.com/coreos/go-systemd/v22/daemon" + "github.com/sirupsen/logrus" + "k8s.io/apiserver/pkg/server/healthz" +) + +func Run(ctx context.Context, socketPath string, interval time.Duration, checkers []healthz.HealthChecker) { + if socketPath == "" { + return + } + if interval <= 0 { + logrus.Debug("systemd watchdog: not enabled by unit, notifier disabled") + return + } + if len(checkers) == 0 { + logrus.Warn("systemd watchdog: no health checks registered, notifier disabled") + return + } + + tick := interval / 2 + names := make([]string, len(checkers)) + for i, c := range checkers { + names[i] = c.Name() + } + logrus.Infof("systemd watchdog: pinging every %s (WatchdogSec=%s), monitoring components %v", + tick, interval, names) + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if name, err := check(checkers); err != nil { + logrus.Warnf("systemd watchdog: %q is unhealthy, withholding WATCHDOG=1: %v", name, err) + continue + } + if err := notify(socketPath, systemd.SdNotifyWatchdog); err != nil { + logrus.Warnf("systemd watchdog: failed to send WATCHDOG=1: %v", err) + } + } + } +} + +func check(checkers []healthz.HealthChecker) (string, error) { + for _, c := range checkers { + if err := c.Check(nil); err != nil { + return c.Name(), err + } + } + return "", nil +} + +func notify(socketPath, state string) error { + if socketPath == "" { + return errors.New("watchdog: empty notify socket path") + } + addr := &net.UnixAddr{Name: socketPath, Net: "unixgram"} + conn, err := net.DialUnix(addr.Net, nil, addr) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.Write([]byte(state)) + return err +} diff --git a/pkg/daemons/watchdog/watchdog_test.go b/pkg/daemons/watchdog/watchdog_test.go new file mode 100644 index 00000000000..512e14aeef6 --- /dev/null +++ b/pkg/daemons/watchdog/watchdog_test.go @@ -0,0 +1,158 @@ +package watchdog + +import ( + "context" + "errors" + "net" + "net/http" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "k8s.io/apiserver/pkg/server/healthz" +) + +// startNotifyListener opens a unix datagram socket at a temporary path and +// returns the path plus a channel that receives every datagram written to it. +// The listener is cleaned up when the test ends. +func startNotifyListener(t *testing.T) (string, <-chan string) { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "notify.sock") + conn, err := net.ListenUnixgram("unixgram", &net.UnixAddr{Name: path, Net: "unixgram"}) + if err != nil { + t.Fatalf("listen unixgram: %v", err) + } + t.Cleanup(func() { + conn.Close() + _ = os.Remove(path) + }) + + out := make(chan string, 16) + go func() { + buf := make([]byte, 1024) + for { + n, _, err := conn.ReadFromUnix(buf) + if err != nil { + close(out) + return + } + out <- string(buf[:n]) + } + }() + return path, out +} + +// fakeChecker returns a healthz.HealthChecker that calls fn and exposes a +// counter of invocations. +type fakeChecker struct { + name string + fn func() error + calls atomic.Int32 +} + +func (f *fakeChecker) Name() string { return f.name } +func (f *fakeChecker) Check(_ *http.Request) error { f.calls.Add(1); return f.fn() } + +func ok(name string) *fakeChecker { + return &fakeChecker{name: name, fn: func() error { return nil }} +} + +func bad(name string, err error) *fakeChecker { + return &fakeChecker{name: name, fn: func() error { return err }} +} + +func Test_UnitWatchdogNoSocketReturnsImmediately(t *testing.T) { + done := make(chan struct{}) + go func() { + Run(context.Background(), "", time.Second, []healthz.HealthChecker{ok("x")}) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Run did not return when socketPath was empty") + } +} + +func Test_UnitWatchdogZeroIntervalReturnsImmediately(t *testing.T) { + socket, _ := startNotifyListener(t) + done := make(chan struct{}) + go func() { + Run(context.Background(), socket, 0, []healthz.HealthChecker{ok("x")}) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Run did not return when interval was zero") + } +} + +func Test_UnitWatchdogEmptyCheckersReturnsImmediately(t *testing.T) { + socket, _ := startNotifyListener(t) + done := make(chan struct{}) + go func() { + Run(context.Background(), socket, time.Second, nil) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Run did not return when checkers was empty") + } +} + +func Test_UnitWatchdogPingsWhenHealthy(t *testing.T) { + socket, msgs := startNotifyListener(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go Run(ctx, socket, 100*time.Millisecond, []healthz.HealthChecker{ok("ok")}) + + select { + case got := <-msgs: + if got != "WATCHDOG=1" { + t.Errorf("expected WATCHDOG=1, got %q", got) + } + case <-time.After(2 * time.Second): + t.Fatal("did not receive WATCHDOG=1 within 2s") + } +} + +func Test_UnitWatchdogWithholdsPingWhenUnhealthy(t *testing.T) { + socket, msgs := startNotifyListener(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + checker := bad("bad", errors.New("unhealthy")) + go Run(ctx, socket, 100*time.Millisecond, []healthz.HealthChecker{checker}) + + select { + case got := <-msgs: + t.Fatalf("did not expect any WATCHDOG=1 ping, got %q", got) + case <-time.After(500 * time.Millisecond): + } + if checker.calls.Load() == 0 { + t.Fatal("expected checker to have been invoked at least once") + } +} + +func Test_UnitWatchdogStopsOnContextCancel(t *testing.T) { + socket, _ := startNotifyListener(t) + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + Run(ctx, socket, 100*time.Millisecond, []healthz.HealthChecker{ok("ok")}) + close(done) + }() + + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Run did not return after ctx cancellation") + } +}