From d34cb72e684be97ec2c84279686a1bdab606ef79 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Wed, 11 Mar 2026 10:30:48 -0400 Subject: [PATCH] Add counting for SSH certs and OTPs (#12368) (#12755) * add cert counting for ssh * add system view and fix errors * add otp counting and change units for certs * add storage tests * fix census errors * run make fmt * use incrementer and change storage to match rfc * run make fmt * fix interface and remove parameter * fix errors * Update builtin/logical/ssh/path_creds_create.go * remove error check * add ssh counts to billing endpoint * fix error * add test case * add ssh metric to test * add get functions and tests * fix format * create function for ssh metrics * refactoring and add test cases * replace test check * add ssh to billing overview test --------- Co-authored-by: Rachel Culpepper <84159930+rculpepper@users.noreply.github.com> Co-authored-by: Victor Rodriguez Rizo --- api/sys_billing_test.go | 23 +- builtin/logical/ssh/backend.go | 15 +- builtin/logical/ssh/path_creds_create.go | 2 + builtin/logical/ssh/path_issue_sign.go | 2 + sdk/logical/certificate_counter.go | 27 +- vault/billing/billing_counts.go | 2 + vault/consumption_billing_util.go | 167 +++++++++++ vault/consumption_billing_util_test.go | 261 ++++++++++++++++++ vault/extended_system_view.go | 9 +- vault/external_tests/api/sys_billing_test.go | 1 + vault/logical_system_helpers.go | 15 + vault/logical_system_use_case_billing.go | 35 +++ vault/logical_system_use_case_billing_test.go | 78 ++++++ 13 files changed, 630 insertions(+), 7 deletions(-) diff --git a/api/sys_billing_test.go b/api/sys_billing_test.go index 7b816b3831..65b177b705 100644 --- a/api/sys_billing_test.go +++ b/api/sys_billing_test.go @@ -33,7 +33,7 @@ func TestSys_BillingOverview(t *testing.T) { currentMonth := resp.Months[0] require.Equal(t, "2026-01", currentMonth.Month) require.Equal(t, "2026-01-14T10:49:00Z", currentMonth.UpdatedAt) - require.Len(t, currentMonth.UsageMetrics, 8, "should have all 8 metrics") + require.Len(t, currentMonth.UsageMetrics, 9, "should have all 9 metrics") // Create a map to verify all expected metrics are present metricsMap := make(map[string]UsageMetric) @@ -51,6 +51,7 @@ func TestSys_BillingOverview(t *testing.T) { "data_protection_calls", "pki_units", "managed_keys", + "ssh_units", } for _, metricName := range expectedMetrics { @@ -86,6 +87,10 @@ func TestSys_BillingOverview(t *testing.T) { require.Equal(t, "external_plugins", externalPluginsMetric.MetricName) require.NotNil(t, externalPluginsMetric.MetricData) require.Contains(t, externalPluginsMetric.MetricData, "total") + + sshMetric := metricsMap["ssh_units"] + require.Contains(t, sshMetric.MetricData, "total") + require.Contains(t, sshMetric.MetricData, "metric_details") } func mockVaultBillingHandler(w http.ResponseWriter, _ *http.Request) { @@ -200,6 +205,22 @@ const billingOverviewResponse = `{ } ] } + }, + { + "metric_name": "ssh_units", + "metric_data": { + "total": 8.4, + "metric_details": [ + { + "type": "otp_units", + "count": 5 + }, + { + "type": "certificate_units", + "count": 3.4 + } + ] + } } ] }, diff --git a/builtin/logical/ssh/backend.go b/builtin/logical/ssh/backend.go index c64c98a47c..13aad152eb 100644 --- a/builtin/logical/ssh/backend.go +++ b/builtin/logical/ssh/backend.go @@ -18,10 +18,11 @@ const operationPrefixSSH = "ssh" type backend struct { *framework.Backend - view logical.Storage - salt *salt.Salt - saltMutex sync.RWMutex - backendUUID string + view logical.Storage + salt *salt.Salt + saltMutex sync.RWMutex + backendUUID string + sshCertificateCounter logical.CertificateCounter } func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { @@ -84,6 +85,12 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { BackendType: logical.TypeLogical, } + if sshCertCounterSysView, ok := conf.System.(logical.CertificateCountSystemView); ok { + b.sshCertificateCounter = sshCertCounterSysView.GetCertificateCounter() + } else { + b.sshCertificateCounter = logical.NewNullCertificateCounter() + } + b.backendUUID = conf.BackendUUID return &b, nil } diff --git a/builtin/logical/ssh/path_creds_create.go b/builtin/logical/ssh/path_creds_create.go index 0f27c6b5de..3827803369 100644 --- a/builtin/logical/ssh/path_creds_create.go +++ b/builtin/logical/ssh/path_creds_create.go @@ -203,6 +203,8 @@ func (b *backend) GenerateOTPCredential(ctx context.Context, req *logical.Reques if err := req.Storage.Put(ctx, newEntry); err != nil { return "", err } + + b.sshCertificateCounter.Increment().AddSSHOTP() return otp, nil } diff --git a/builtin/logical/ssh/path_issue_sign.go b/builtin/logical/ssh/path_issue_sign.go index f0a87e3dbf..baf5574d38 100644 --- a/builtin/logical/ssh/path_issue_sign.go +++ b/builtin/logical/ssh/path_issue_sign.go @@ -129,6 +129,8 @@ func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logic return nil, nil, errors.New("error marshaling signed certificate") } + b.sshCertificateCounter.Increment().AddSSHCertificate(ttl) + response := &logical.Response{ Data: map[string]interface{}{ "serial_number": strconv.FormatUint(certificate.Serial, 16), diff --git a/sdk/logical/certificate_counter.go b/sdk/logical/certificate_counter.go index 0504c06d4d..9f29f66d97 100644 --- a/sdk/logical/certificate_counter.go +++ b/sdk/logical/certificate_counter.go @@ -6,6 +6,7 @@ package logical import ( "crypto/x509" "math" + "time" ) // CertificateCounter is an interface for incrementing the count of issued and stored @@ -28,16 +29,20 @@ type CertCount struct { // purposes. Each certificate's billable units = (Validity Hours รท 730), rounded to 4 decimal // places. PkiDurationAdjustedCerts float64 + SSHIssuedCerts float64 + SSHIssuedOTPs uint64 } func (i *CertCount) Add(other CertCount) { i.IssuedCerts += other.IssuedCerts i.StoredCerts += other.StoredCerts i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts + i.SSHIssuedCerts += other.SSHIssuedCerts + i.SSHIssuedOTPs += other.SSHIssuedOTPs } func (i *CertCount) IsZero() bool { - return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 + return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 && i.SSHIssuedCerts == 0 && i.SSHIssuedOTPs == 0 } // durationAdjustedCertificateCount calculates the billable units for a certificate based on its @@ -58,6 +63,8 @@ func durationAdjustedCertificateCount(validitySeconds int64) float64 { type CertCountIncrementer interface { AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer + AddSSHCertificate(ttl time.Duration) CertCountIncrementer + AddSSHOTP() CertCountIncrementer } type certCountIncrementer struct { @@ -88,3 +95,21 @@ func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Cert return c } + +func (c *certCountIncrementer) AddSSHCertificate(ttl time.Duration) CertCountIncrementer { + count := CertCount{ + SSHIssuedCerts: durationAdjustedCertificateCount(int64(ttl.Seconds())), + } + + c.counter.AddCount(count) + + return c +} + +func (c *certCountIncrementer) AddSSHOTP() CertCountIncrementer { + c.counter.AddCount(CertCount{ + SSHIssuedOTPs: 1, + }) + + return c +} diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index 1e43a9d9f1..1ee8cb0884 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -29,6 +29,8 @@ const ( KmipEnabledPrefix = "kmipEnabled/" PkiDurationAdjustedCountPrefix = "normalizedCertsIssued/" MetricsLastUpdatedAtPrefix = "metricsLastUpdatedAt/" + SSHCertificateMetric = "ssh/normalized-certs-issued" + SSHOTPMetric = "ssh/credential-count" BillingWriteInterval = 10 * time.Minute // pluginCountsSendTimeout is the timeout for sending plugin counts to the active node diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index 64c05b32b7..9e19de8947 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -11,6 +11,7 @@ import ( "time" "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/billing" ) @@ -758,3 +759,169 @@ func (c *Core) UpdateMetricsLastUpdateTime(ctx context.Context, currentMonth, up return c.storeMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, normalizedMonth, updateTime) } + +// GetStoredSSHDurationAdjustedCertCount retrieves the stored SSH duration-adjusted certificate count +// for the specified month. The count is stored as a float64. +// Returns 0 if no count has been stored for the given month. +func (c *Core) GetStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time) (float64, error) { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return 0, errors.New("consumption billing is not initialized") + } + + cb.BillingStorageLock.RLock() + defer cb.BillingStorageLock.RUnlock() + + return c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth) +} + +func (c *Core) getStoredSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (float64, error) { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric) + + view, ok := c.GetBillingSubView() + if !ok { + return 0, errors.New("error reading SSH duration-adjusted count: billing subview not available") + } + + se, err := view.Get(ctx, billingPath) + if se == nil || err != nil { + return 0, err + } + + var certCount float64 + err = se.DecodeJSON(&certCount) + if err != nil { + return 0, fmt.Errorf("error decoding current SSH duration adjusted cert count: %w", err) + } + + return certCount, nil +} + +func (c *Core) UpdateStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time, certCount float64) (float64, error) { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return 0, ErrConsumptionBillingNotInitialized + } + cb.BillingStorageLock.Lock() + defer cb.BillingStorageLock.Unlock() + storedCertCount, err := c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth) + if err != nil { + return 0, err + } + + err = c.storeSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth, certCount+storedCertCount) + if err != nil { + return 0, err + } + + return certCount, nil +} + +func (c *Core) storeSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, certCount float64) error { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric) + + countBytes, err := jsonutil.EncodeJSON(certCount) + if err != nil { + return err + } + + entry := &logical.StorageEntry{ + Key: billingPath, + Value: countBytes, + } + + view, ok := c.GetBillingSubView() + if !ok { + return nil + } + return view.Put(ctx, entry) +} + +// GetStoredSSHOTPCount retrieves the stored SSH OTP count +// for the specified month. The count is stored as a uint64. +// Returns 0 if no count has been stored for the given month. +func (c *Core) GetStoredSSHOTPCount(ctx context.Context, currentMonth time.Time) (uint64, error) { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return 0, errors.New("consumption billing is not initialized") + } + + cb.BillingStorageLock.RLock() + defer cb.BillingStorageLock.RUnlock() + + return c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth) +} + +func (c *Core) getStoredSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (uint64, error) { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric) + + view, ok := c.GetBillingSubView() + if !ok { + return 0, errors.New("error reading SSH OTP count: billing subview not available") + } + + se, err := view.Get(ctx, billingPath) + if se == nil || err != nil { + return 0, err + } + + var otpCount uint64 + err = se.DecodeJSON(&otpCount) + if err != nil { + return 0, fmt.Errorf("error decoding current OTP cert count: %w", err) + } + + return otpCount, nil +} + +func (c *Core) UpdateStoredSSHOTPCount(ctx context.Context, currentMonth time.Time, otpCount uint64) (uint64, error) { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return 0, ErrConsumptionBillingNotInitialized + } + cb.BillingStorageLock.Lock() + defer cb.BillingStorageLock.Unlock() + storedOTPCount, err := c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth) + if err != nil { + return 0, err + } + + err = c.storeSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth, otpCount+storedOTPCount) + if err != nil { + return 0, err + } + + return otpCount, nil +} + +func (c *Core) storeSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, otpCount uint64) error { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric) + + countBytes, err := jsonutil.EncodeJSON(otpCount) + if err != nil { + return err + } + + entry := &logical.StorageEntry{ + Key: billingPath, + Value: countBytes, + } + + view, ok := c.GetBillingSubView() + if !ok { + return nil + } + return view.Put(ctx, entry) +} diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index b0d8c60d40..6fa127d2b0 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -6,6 +6,7 @@ package vault import ( "context" "fmt" + "math" "testing" "time" @@ -23,6 +24,7 @@ import ( logicalDatabase "github.com/hashicorp/vault/builtin/logical/database" logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad" logicalRabbitMQ "github.com/hashicorp/vault/builtin/logical/rabbitmq" + "github.com/hashicorp/vault/builtin/logical/ssh" "github.com/hashicorp/vault/builtin/logical/totp" "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/helper/namespace" @@ -809,6 +811,265 @@ func TestTransitDataProtectionCallCounts(t *testing.T) { require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0") } +// TestSSHCertCounts tests that we correctly store and track the SSH certificate counts +func TestSSHCertCounts(t *testing.T) { + standardDuration := 730.0 + validityHours := float64(60*60*24) / 3600.0 + units := validityHours / standardDuration + // Round to 4 decimal places + expectedCertUnit := math.Round(units*10000) / 10000 + + t.Parallel() + coreConfig := &CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "ssh": ssh.Factory, + }, + } + + core, _, root := TestCoreUnsealedWithConfig(t, coreConfig) + + // Mount SSH backend + req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh") + req.Data["type"] = "ssh" + req.ClientToken = root + ctx := namespace.RootContext(context.Background()) + _, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Create a certificate + req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca") + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test") + req.ClientToken = root + req.Data["key_type"] = "ca" + req.Data["allow_user_certificates"] = true + req.Data["allow_empty_principals"] = true + req.Data["ttl"] = "1d" + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test") + req.ClientToken = root + resp, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Verify that the SSH counter is incremented + require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts) + + // Test sign endpoint + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test") + req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key" + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Verify that the SSH counter is incremented + require.Equal(t, expectedCertUnit*2, core.certCountManager.GetCounts().SSHIssuedCerts) + + // Now test persisting the summed counts - store and retrieve counts + // First, update the SSH cert counts (this will sum current counter with stored value) + currentCount := core.certCountManager.GetCounts().SSHIssuedCerts + core.certCountManager.StopConsumerJob() + + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update") + + // Retrieve the stored counts + storedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, currentCount, storedCounts) + + core.certCountManager.StartConsumerJob(core.consumeCertCounts) + + // Perform more operations to increase the counter + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test") + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Counter should now be 1 cert + require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts) + + // Update counts again - should sum the new count with the stored count + core.certCountManager.StopConsumerJob() + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update") + + // Verify stored counts are now the sum + summedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now()) + require.NoError(t, err) + + expectedSum := currentCount + expectedCertUnit + require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current") + + core.certCountManager.StartConsumerJob(core.consumeCertCounts) + + // Add more operations without manually resetting + for i := 0; i < 3; i++ { + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test") + req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key" + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + } + + // Counter should be 3 certs + require.Equal(t, expectedCertUnit*3, core.certCountManager.GetCounts().SSHIssuedCerts) + + // Update counts - should sum 3 with the previous stored sum + core.certCountManager.StopConsumerJob() + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update") + + // Verify stored counts + storedCounts, err = core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now()) + require.NoError(t, err) + expectedSum += expectedCertUnit * 3 + require.Equal(t, expectedSum, storedCounts) +} + +// TestSSHOTPCounts tests that we correctly store and track the SSH OTP counts +func TestSSHOTPCounts(t *testing.T) { + t.Parallel() + coreConfig := &CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "ssh": ssh.Factory, + }, + } + + core, _, root := TestCoreUnsealedWithConfig(t, coreConfig) + + // Mount SSH backend + req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh") + req.Data["type"] = "ssh" + req.ClientToken = root + ctx := namespace.RootContext(context.Background()) + _, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Create a certificate + req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca") + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test") + req.ClientToken = root + req.Data["key_type"] = "otp" + req.Data["default_user"] = "user" + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress") + req.ClientToken = root + req.Data["roles"] = "test" + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test") + req.ClientToken = root + req.Data["ip"] = "1.2.3.4" + resp, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Verify that the SSH counter is incremented + require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs) + + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test") + req.ClientToken = root + req.Data["ip"] = "1.2.3.4" + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Verify that the SSH counter is incremented + require.Equal(t, uint64(2), core.certCountManager.GetCounts().SSHIssuedOTPs) + + // Now test persisting the summed counts - store and retrieve counts + // First, update the SSH cert counts (this will sum current counter with stored value) + currentCount := core.certCountManager.GetCounts().SSHIssuedOTPs + core.certCountManager.StopConsumerJob() + + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update") + + // Retrieve the stored counts + storedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, currentCount, storedCounts) + + core.certCountManager.StartConsumerJob(core.consumeCertCounts) + + // Perform more operations to increase the counter + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test") + req.ClientToken = root + req.Data["ip"] = "1.2.3.4" + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + // Counter should now be 1 + require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs) + + // Update counts again - should sum the new count with the stored count + core.certCountManager.StopConsumerJob() + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update") + + // Verify stored counts are now the sum + summedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now()) + require.NoError(t, err) + + expectedSum := currentCount + 1 + require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current") + + core.certCountManager.StartConsumerJob(core.consumeCertCounts) + + // Add more operations without manually resetting + for i := 0; i < 3; i++ { + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test") + req.ClientToken = root + req.Data["ip"] = "1.2.3.4" + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + } + + // Counter should be 3 + require.Equal(t, uint64(3), core.certCountManager.GetCounts().SSHIssuedOTPs) + + // Update counts - should sum 3 with the previous stored sum + core.certCountManager.StopConsumerJob() + time.Sleep(20 * time.Millisecond) + + // Verify the counter was reset after update + require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update") + + // Verify stored counts + storedCounts, err = core.GetStoredSSHOTPCount(ctx, time.Now()) + require.NoError(t, err) + expectedSum += 3 + require.Equal(t, expectedSum, storedCounts) +} + func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) { raw, ok := core.router.root.Get(mount + "/") if !ok { diff --git a/vault/extended_system_view.go b/vault/extended_system_view.go index 12c20bb7e0..e9b1b2fda1 100644 --- a/vault/extended_system_view.go +++ b/vault/extended_system_view.go @@ -13,7 +13,10 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -var _ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil) +var ( + _ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil) + _ logical.CertificateCountSystemView = (*extendedSystemViewImpl)(nil) +) type extendedSystemViewImpl struct { dynamicSystemView @@ -152,3 +155,7 @@ func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context, func (e extendedSystemViewImpl) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) { return e.core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName) } + +func (e extendedSystemViewImpl) GetCertificateCounter() logical.CertificateCounter { + return e.core.GetCertificateCounter() +} diff --git a/vault/external_tests/api/sys_billing_test.go b/vault/external_tests/api/sys_billing_test.go index 11078b8db8..4400a3af92 100644 --- a/vault/external_tests/api/sys_billing_test.go +++ b/vault/external_tests/api/sys_billing_test.go @@ -172,6 +172,7 @@ func Test_BillingOverview_EmptyCluster(t *testing.T) { "data_protection_calls": false, "pki_units": false, "managed_keys": false, + "ssh_units": false, } for _, metric := range currentMonth.UsageMetrics { diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index 5b5dda1449..4ce29be3cb 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -332,6 +332,21 @@ func (c *Core) consumeCertCounts(inc logical.CertCount) { unconsumed.PkiDurationAdjustedCerts = 0 } + c.logger.Info("storing SSH counts", "sshDurationAdjustedCount", inc.SSHIssuedCerts, "sshOTPCount", inc.SSHIssuedOTPs) + _, err = c.UpdateStoredSSHDurationAdjustedCertCount(c.activeContext, time.Now(), inc.SSHIssuedCerts) + if err != nil { + c.logger.Error("error storing SSH duration adjusted certificate count", "error", err) + } else { + unconsumed.SSHIssuedCerts = 0 + } + + _, err = c.UpdateStoredSSHOTPCount(c.activeContext, time.Now(), inc.SSHIssuedOTPs) + if err != nil { + c.logger.Error("error storing SSH OTP count", "error", err) + } else { + unconsumed.SSHIssuedOTPs = 0 + } + default: c.logger.Error("Unexpected HA state when consuming certificate counts", "ha_state", haState) } diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index a39d6a9bc5..e532413173 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -210,6 +210,12 @@ func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Ti }, }) + sshCounts, err := b.buildSSHMetric(ctx, month) + if err != nil { + return nil, err + } + usageMetrics = append(usageMetrics, sshCounts) + dataUpdatedAt := b.Core.computeUpdatedAt(ctx, month, currentMonth) monthStr := month.Format("2006-01") @@ -486,3 +492,32 @@ func (c *Core) getThirdPartyPluginCounts(ctx context.Context, month time.Time) ( return thirdPartyPluginCounts, nil } + +func (b *SystemBackend) buildSSHMetric(ctx context.Context, month time.Time) (map[string]interface{}, error) { + certCounts, err := b.Core.GetStoredSSHDurationAdjustedCertCount(ctx, month) + if err != nil { + return nil, fmt.Errorf("error retrieving SSH duration-adjuested cert counts for current month: %w", err) + } + + otpCounts, err := b.Core.GetStoredSSHOTPCount(ctx, month) + if err != nil { + return nil, fmt.Errorf("error retrieving SSH OTP counts for current month: %w", err) + } + + return map[string]interface{}{ + "metric_name": "ssh_units", + "metric_data": map[string]interface{}{ + "total": certCounts + float64(otpCounts), + "metric_details": []map[string]interface{}{ + { + "type": "otp_units", + "count": otpCounts, + }, + { + "type": "certificate_units", + "count": certCounts, + }, + }, + }, + }, nil +} diff --git a/vault/logical_system_use_case_billing_test.go b/vault/logical_system_use_case_billing_test.go index bc5d2eceba..e5763605fd 100644 --- a/vault/logical_system_use_case_billing_test.go +++ b/vault/logical_system_use_case_billing_test.go @@ -11,6 +11,7 @@ import ( logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" logicalAws "github.com/hashicorp/vault/builtin/logical/aws" logicalDatabase "github.com/hashicorp/vault/builtin/logical/database" + logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh" logicalTransit "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pluginconsts" @@ -172,6 +173,7 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) { pluginconsts.SecretEngineAWS: logicalAws.Factory, pluginconsts.SecretEngineDatabase: logicalDatabase.Factory, pluginconsts.SecretEngineTransit: logicalTransit.Factory, + pluginconsts.SecretEngineSsh: logicalSsh.Factory, }, }) b := c.systemBackend @@ -233,6 +235,53 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) { _, err = c.HandleRequest(ctx, req) require.NoError(t, err) + // Create SSH certificate and OTP + req = logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh") + req.Data["type"] = "ssh" + req.ClientToken = root + resp, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca") + req.ClientToken = root + resp, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-cert") + req.ClientToken = root + req.Data["key_type"] = "ca" + req.Data["allow_user_certificates"] = true + req.Data["allow_empty_principals"] = true + req.Data["ttl"] = "1d" + _, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test-cert") + req.ClientToken = root + resp, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-otp") + req.ClientToken = root + req.Data["key_type"] = "otp" + req.Data["default_user"] = "user" + _, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress") + req.ClientToken = root + req.Data["roles"] = "test-otp" + _, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + + req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test-otp") + req.ClientToken = root + req.Data["ip"] = "1.2.3.4" + resp, err = c.HandleRequest(ctx, req) + require.NoError(t, err) + require.Nil(t, resp.Error()) + // Update all metrics currentMonth := time.Now() _, err = c.UpdateMaxKvCounts(ctx, billing.LocalPrefix, currentMonth) @@ -244,6 +293,12 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) { _, err = c.UpdateTransitCallCounts(ctx, currentMonth) require.NoError(t, err) + _, err = c.UpdateStoredSSHDurationAdjustedCertCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedCerts) + require.NoError(t, err) + + _, err = c.UpdateStoredSSHOTPCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedOTPs) + require.NoError(t, err) + // Make a request to the billing overview endpoint req = logical.TestRequest(t, logical.ReadOperation, "billing/overview") req.Data["refresh_data"] = true @@ -366,6 +421,24 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) { require.True(t, ok, "managed_keys total should be int") require.GreaterOrEqual(t, total, 0) require.Contains(t, metricData, "metric_details") + + case "ssh_units": + require.Contains(t, metricData, "total") + total, ok := metricData["total"].(float64) + require.True(t, ok, "ssh_units total should be float64") + require.GreaterOrEqual(t, total, float64(0)) + + require.Contains(t, metricData, "metric_details") + metricDetails, ok := metricData["metric_details"].([]map[string]interface{}) + require.True(t, ok, "metric_details should be []map[string]interface{}") + require.NotEmpty(t, metricDetails) + require.Equal(t, len(metricDetails), 2) + + require.Equal(t, metricDetails[0]["type"], "otp_units") + require.GreaterOrEqual(t, metricDetails[0]["count"], uint64(0)) + + require.Equal(t, metricDetails[1]["type"], "certificate_units") + require.GreaterOrEqual(t, metricDetails[1]["count"], float64(0)) } } @@ -469,6 +542,7 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) { "data_protection_calls": false, "pki_units": false, "managed_keys": false, + "ssh_units": false, } for _, metric := range usageMetrics { @@ -522,6 +596,10 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) { details, ok := metricData["metric_details"].([]map[string]interface{}) require.True(t, ok, "%s metric_details should be array", metricName) require.Empty(t, details, "%s metric_details should be empty when total is 0", metricName) + case "ssh_units": + total, ok := metricData["total"].(float64) + require.True(t, ok, "ssh_units total should be float64") + require.Equal(t, float64(0), total, "ssh_units total should be 0") } }