diff --git a/api/sys_billing_test.go b/api/sys_billing_test.go index e84f6be994..7b816b3831 100644 --- a/api/sys_billing_test.go +++ b/api/sys_billing_test.go @@ -33,32 +33,45 @@ func TestSys_BillingOverview(t *testing.T) { currentMonth := resp.Months[0] require.Equal(t, "2026-01", currentMonth.Month) require.Equal(t, "2026-01-14T10:49:00Z", currentMonth.UpdatedAt) - require.Len(t, currentMonth.UsageMetrics, 4) + require.Len(t, currentMonth.UsageMetrics, 8, "should have all 8 metrics") - // Verify static_secrets metric - staticSecretsMetric := currentMonth.UsageMetrics[0] - require.Equal(t, "static_secrets", staticSecretsMetric.MetricName) - require.NotNil(t, staticSecretsMetric.MetricData) + // Create a map to verify all expected metrics are present + metricsMap := make(map[string]UsageMetric) + for _, metric := range currentMonth.UsageMetrics { + metricsMap[metric.MetricName] = metric + } + + // Verify all expected metrics are present + expectedMetrics := []string{ + "static_secrets", + "dynamic_roles", + "auto_rotated_roles", + "kmip", + "external_plugins", + "data_protection_calls", + "pki_units", + "managed_keys", + } + + for _, metricName := range expectedMetrics { + metric, exists := metricsMap[metricName] + require.True(t, exists, "metric %s should be present", metricName) + require.NotNil(t, metric.MetricData, "metric_data should not be nil for %s", metricName) + } + + // Verify specific metric structures + staticSecretsMetric := metricsMap["static_secrets"] require.Contains(t, staticSecretsMetric.MetricData, "total") require.Contains(t, staticSecretsMetric.MetricData, "metric_details") - // Verify kmip metric - kmipMetric := currentMonth.UsageMetrics[1] - require.Equal(t, "kmip", kmipMetric.MetricName) - require.NotNil(t, kmipMetric.MetricData) + kmipMetric := metricsMap["kmip"] require.Contains(t, kmipMetric.MetricData, "used_in_month") require.Equal(t, true, kmipMetric.MetricData["used_in_month"]) - // Verify pki_units metric - pkiMetric := 
currentMonth.UsageMetrics[2] - require.Equal(t, "pki_units", pkiMetric.MetricName) - require.NotNil(t, pkiMetric.MetricData) + pkiMetric := metricsMap["pki_units"] require.Contains(t, pkiMetric.MetricData, "total") - // Verify managed_keys metric - managedKeysMetric := currentMonth.UsageMetrics[3] - require.Equal(t, "managed_keys", managedKeysMetric.MetricName) - require.NotNil(t, managedKeysMetric.MetricData) + managedKeysMetric := metricsMap["managed_keys"] require.Contains(t, managedKeysMetric.MetricData, "total") require.Contains(t, managedKeysMetric.MetricData, "metric_details") @@ -102,12 +115,70 @@ const billingOverviewResponse = `{ ] } }, + { + "metric_name": "dynamic_roles", + "metric_data": { + "total": 15, + "metric_details": [ + { + "type": "aws_dynamic", + "count": 5 + }, + { + "type": "azure_dynamic", + "count": 5 + }, + { + "type": "database_dynamic", + "count": 5 + } + ] + } + }, + { + "metric_name": "auto_rotated_roles", + "metric_data": { + "total": 10, + "metric_details": [ + { + "type": "aws_static", + "count": 5 + }, + { + "type": "azure_static", + "count": 5 + } + ] + } + }, { "metric_name": "kmip", "metric_data": { "used_in_month": true } }, + { + "metric_name": "external_plugins", + "metric_data": { + "total": 3 + } + }, + { + "metric_name": "data_protection_calls", + "metric_data": { + "total": 100, + "metric_details": [ + { + "type": "transit", + "count": 50 + }, + { + "type": "transform", + "count": 50 + } + ] + } + }, { "metric_name": "pki_units", "metric_data": { diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index 4c03a55a6a..1e43a9d9f1 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -28,6 +28,7 @@ const ( ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" KmipEnabledPrefix = "kmipEnabled/" PkiDurationAdjustedCountPrefix = "normalizedCertsIssued/" + MetricsLastUpdatedAtPrefix = "metricsLastUpdatedAt/" BillingWriteInterval = 10 * time.Minute // 
pluginCountsSendTimeout is the timeout for sending plugin counts to the active node @@ -49,11 +50,6 @@ type ConsumptionBilling struct { // KmipSeenEnabledThisMonth tracks whether KMIP has been enabled during the current billing month. // This is used to avoid scanning all mounts every 10 minutes for KMIP billing detection. KmipSeenEnabledThisMonth atomic.Bool - - // LastMetricsUpdate tracks when billing metrics were last updated, either by the background worker - // or by the billing endpoint API call. This timestamp is used by the billing overview endpoint to - // indicate data freshness. - LastMetricsUpdate atomic.Value } type BillingConfig struct { diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index e26417e9bd..931dccb452 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -33,6 +33,7 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error { Logger: logger, } c.consumptionBillingLock.Unlock() + c.postUnsealFuncs = append(c.postUnsealFuncs, func() { c.consumptionBillingMetricsWorker(ctx) // Start the perf standby plugin counts worker if this is a perf standby @@ -113,6 +114,8 @@ func (c *Core) HandleStartOfMonth(ctx context.Context, currentMonth time.Time) { if err := c.resetInMemoryBillingMetrics(); err != nil { c.logger.Error("error resetting in memory billing metrics", "error", err) } + // Reset the metrics last update time to zero time to indicate new month data hasn't been updated yet + c.UpdateMetricsLastUpdateTime(ctx, currentMonth, time.Time{}) } func (c *Core) deletePreviousMonthBillingMetrics(ctx context.Context, currentMonth time.Time) error { @@ -153,7 +156,8 @@ func (c *Core) resetInMemoryBillingMetrics() error { return nil } -func (c *Core) updateBillingMetrics(ctx context.Context, currentMonth time.Time) error { +// updateBillingMetricsLocked must be called with stateLock already held. 
+func (c *Core) updateBillingMetricsLocked(ctx context.Context, currentMonth time.Time) error { // Check if systemBarrierView is initialized c.mountsLock.RLock() initialized := c.systemBarrierView != nil @@ -162,11 +166,11 @@ func (c *Core) updateBillingMetrics(ctx context.Context, currentMonth time.Time) if !initialized { return nil } - if c.PerfStandby() { + if c.perfStandby { // We do not update billing metrics on performance standbys // Instead we send any in memory counts to the primary. This doesn't apply // to role counts, but will be used for other metrics - } else if standby, _ := c.Standby(); standby { + } else if c.standby { // Do nothing if we are a standby. All requests get forwarded anyway } else { // The active node will need to flush max role counts to storage @@ -180,11 +184,21 @@ func (c *Core) updateBillingMetrics(ctx context.Context, currentMonth time.Time) c.logger.Info("updated cluster data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) } - c.consumptionBilling.LastMetricsUpdate.Store(time.Now().UTC()) + // Store the last metrics update time. This is used to determine the freshness of the billing data. + // We store this on the active node only, since this is the node that updates the billing metrics. + // The standby nodes will replicate this value, so it will be available on all nodes, but we avoid + // having all nodes write to this value to avoid write conflicts. 
+ c.UpdateMetricsLastUpdateTime(ctx, currentMonth, time.Now().UTC()) } return nil } +func (c *Core) updateBillingMetrics(ctx context.Context, currentMonth time.Time) error { + c.stateLock.RLock() + defer c.stateLock.RUnlock() + return c.updateBillingMetricsLocked(ctx, currentMonth) +} + func (c *Core) UpdateReplicatedHWMMetrics(ctx context.Context, currentMonth time.Time) error { _, _, err := c.UpdateMaxRoleAndManagedKeyCounts(ctx, billing.ReplicatedPrefix, currentMonth) if err != nil { diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index aa96e901dd..33ca3abd39 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/billing" ) @@ -259,6 +260,14 @@ func (c *Core) UpdateMaxRoleAndManagedKeyCounts(ctx context.Context, localPathPr return nil, nil, err } + // Add nil checks before dereferencing + if currentRoleCounts == nil { + currentRoleCounts = &RoleCounts{} + } + if currentManagedKeyCounts == nil { + currentManagedKeyCounts = &ManagedKeyCounts{} + } + // get max role counts maxRoleCounts, err := c.updateMaxRoleCounts(ctx, currentRoleCounts, localPathPrefix, currentMonth) if err != nil { @@ -684,3 +693,77 @@ func (c *Core) storePkiDurationAdjustedCountLocked(ctx context.Context, localPat return nil } + +// storeMetricsLastUpdateTimeLocked must be called with BillingStorageLock held +func (c *Core) storeMetricsLastUpdateTimeLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, updateTime time.Time) error { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.MetricsLastUpdatedAtPrefix) + entry := &logical.StorageEntry{ + Key: billingPath, + Value: []byte(updateTime.Format(time.RFC3339)), + } + view, ok := c.GetBillingSubView() + if !ok { + return nil + } + return view.Put(ctx, 
entry) +} + +// getMetricsLastUpdateTimeLocked retrieves the timestamp of the last billing metrics update for the given month. If no value is stored, the zero time is returned. It must be called with BillingStorageLock held. +func (c *Core) getMetricsLastUpdateTimeLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (time.Time, error) { + billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.MetricsLastUpdatedAtPrefix) + view, ok := c.GetBillingSubView() + if !ok { + return time.Time{}, nil + } + entry, err := view.Get(ctx, billingPath) + if err != nil { + return time.Time{}, err + } + if entry == nil { + return time.Time{}, nil + } + updateTime, err := time.Parse(time.RFC3339, string(entry.Value)) + if err != nil { + return time.Time{}, err + } + return updateTime, nil +} + +func (c *Core) GetMetricsLastUpdateTime(ctx context.Context, currentMonth time.Time) (time.Time, error) { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return time.Time{}, ErrConsumptionBillingNotInitialized + } + + // Normalize month to UTC start-of-month to avoid timezone/midnight mismatches + normalizedMonth := timeutil.StartOfMonth(currentMonth.UTC()) + + cb.BillingStorageLock.RLock() + defer cb.BillingStorageLock.RUnlock() + return c.getMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, normalizedMonth) +} + +// UpdateMetricsLastUpdateTime stores updateTime (converted to UTC) as the last update time for billing metrics for the given month. It returns an error if consumption billing is not initialized or the storage write fails. +// Note that this last metrics update time is per cluster. It does NOT de-duplicate across clusters. For that reason, +// we will always store the time at the "local" prefix. 
+func (c *Core) UpdateMetricsLastUpdateTime(ctx context.Context, currentMonth, updateTime time.Time) error { + c.consumptionBillingLock.RLock() + cb := c.consumptionBilling + c.consumptionBillingLock.RUnlock() + + if cb == nil { + return ErrConsumptionBillingNotInitialized + } + + // Normalize month to UTC start-of-month and ensure updateTime is in UTC + normalizedMonth := timeutil.StartOfMonth(currentMonth.UTC()) + updateTime = updateTime.UTC() + + cb.BillingStorageLock.Lock() + defer cb.BillingStorageLock.Unlock() + + return c.storeMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, normalizedMonth, updateTime) +} diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index 094b48cbcb..b0d8c60d40 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -399,6 +399,51 @@ func TestHWMKvSecretsCounts(t *testing.T) { require.Equal(t, 5, counts) } +// TestStoreAndGetMetricsLastUpdateTimeLocked tests the store/get helpers +// that operate under the BillingStorageLock +func TestStoreAndGetMetricsLastUpdateTimeLocked(t *testing.T) { + coreConfig := &CoreConfig{} + core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) + + ctx := namespace.RootContext(context.Background()) + month := time.Now() + updateTime := time.Now().UTC().Truncate(time.Second) + + // Acquire billing storage lock as required by the helper contract + core.consumptionBilling.BillingStorageLock.Lock() + defer core.consumptionBilling.BillingStorageLock.Unlock() + + // Store under local prefix and verify + err := core.storeMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, month, updateTime) + require.NoError(t, err) + + got, err := core.getMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, month) + require.NoError(t, err) + require.Equal(t, updateTime.Format(time.RFC3339), got.Format(time.RFC3339)) + + // Ensure other prefix returns zero when not set + gotReplicated, err := core.getMetricsLastUpdateTimeLocked(ctx, 
billing.ReplicatedPrefix, month) + require.NoError(t, err) + require.True(t, gotReplicated.IsZero(), "replicated prefix should have no stored timestamp") +} + +// TestUpdateAndGetMetricsLastUpdateTime tests the public Update/Get helpers for the metrics last update time +func TestUpdateAndGetMetricsLastUpdateTime(t *testing.T) { + coreConfig := &CoreConfig{} + core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) + + ctx := namespace.RootContext(context.Background()) + month := time.Now() + updateTime := time.Now().UTC().Truncate(time.Second) + + err := core.UpdateMetricsLastUpdateTime(ctx, month, updateTime) + require.NoError(t, err) + + got, err := core.GetMetricsLastUpdateTime(ctx, month) + require.NoError(t, err) + require.Equal(t, updateTime.Format(time.RFC3339), got.Format(time.RFC3339)) +} + // TestStoreAndGetMaxTotpKeyCounts verifies that we can store and retrieve the HWM totp key counts correctly func TestStoreAndGetMaxTotpKeyCounts(t *testing.T) { coreConfig := &CoreConfig{ diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index a877efc5ce..3a463f7821 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -15,7 +15,11 @@ import ( "github.com/hashicorp/vault/vault/billing" ) -const pkiDurationAjustedCountMetricName = "pki_units" +const ( + WarningRefreshIgnoredOnStandby = "refresh_data parameter is supported only on the active node. " + + "Since this parameter was set on a performance standby, the billing data was not refreshed " + + "and retrieved from storage without update." 
+) func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path { return []*framework.Path{ @@ -64,6 +68,16 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic currentMonth := time.Now() previousMonth := timeutil.StartOfPreviousMonth(currentMonth) + warnings := make([]string, 0) + + // Check if this is a performance standby and if refreshData is true, + // and add a warning that refresh will be ignored in this case. + // We do not need to hold stateLock here since HandleRequest is already holding this lock. + if refreshData && b.Core.perfStandby { + warnings = append(warnings, WarningRefreshIgnoredOnStandby) + refreshData = false + } + // Refresh data only if explicitly requested and for current month currentMonthData, err := b.buildMonthBillingData(ctx, currentMonth, refreshData) if err != nil { @@ -83,34 +97,45 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic } return &logical.Response{ - Data: resp, + Data: resp, + Warnings: warnings, }, nil } // buildMonthBillingData constructs billing data for a specific month func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Time, refreshData bool) (map[string]interface{}, error) { + currentMonth := timeutil.StartOfMonth(time.Now().UTC()) + // Check if the billing metrics need to be refreshed. We're running + // under the core stateLock during request handling, so call the no-lock helper to + // avoid recursive locking. 
+ if refreshData { + if err := b.Core.updateBillingMetricsLocked(ctx, currentMonth); err != nil { + return nil, fmt.Errorf("error refreshing billing metrics: %w", err) + } + } + // Retrieve all billing metrics - combinedRoleCounts, combinedManagedKeyCounts, err := b.Core.getRoleAndManagedKeyCounts(ctx, month, refreshData) + combinedRoleCounts, combinedManagedKeyCounts, err := b.Core.getRoleAndManagedKeyCounts(ctx, month) if err != nil { return nil, err } - combinedKvCounts, err := b.Core.getKvCounts(ctx, month, refreshData) + combinedKvCounts, err := b.Core.getKvCounts(ctx, month) if err != nil { return nil, err } - transitCounts, transformCounts, err := b.Core.getDataProtectionCounts(ctx, month, refreshData) + transitCounts, transformCounts, err := b.Core.getDataProtectionCounts(ctx, month) if err != nil { return nil, err } - kmipEnabled, err := b.Core.getKmipStatus(ctx, month, refreshData) + kmipEnabled, err := b.Core.getKmipStatus(ctx, month) if err != nil { return nil, err } - thirdPartyPluginCounts, err := b.Core.getThirdPartyPluginCounts(ctx, month, refreshData) + thirdPartyPluginCounts, err := b.Core.getThirdPartyPluginCounts(ctx, month) if err != nil { return nil, err } @@ -185,27 +210,7 @@ func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Ti }, }) - // Determine updated_at timestamp based on whether data was refreshed - var dataUpdatedAt time.Time - if refreshData { - // Data was just refreshed, use current time and update the stored timestamp - dataUpdatedAt = time.Now().UTC() - b.Core.consumptionBilling.LastMetricsUpdate.Store(dataUpdatedAt) - } else { - // Data was not refreshed, use the last time metrics were updated by the background worker - lastUpdate := b.Core.consumptionBilling.LastMetricsUpdate.Load() - if lastUpdate != nil { - if t, ok := lastUpdate.(time.Time); ok && !t.IsZero() { - dataUpdatedAt = t - } else { - // Fallback to end of month if timestamp not available - dataUpdatedAt = 
timeutil.StartOfMonth(month.AddDate(0, 1, 0)).Add(-time.Second).UTC() - } - } else { - // Fallback to end of month if timestamp not available - dataUpdatedAt = timeutil.StartOfMonth(month.AddDate(0, 1, 0)).Add(-time.Second).UTC() - } - } + dataUpdatedAt := b.Core.computeUpdatedAt(ctx, month, currentMonth) monthStr := month.Format("2006-01") @@ -216,6 +221,40 @@ func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Ti }, nil } +// computeUpdatedAt determines the appropriate updated_at timestamp for billing data +func (c *Core) computeUpdatedAt(ctx context.Context, month, currentMonth time.Time) time.Time { + var dataUpdatedAt time.Time + isCurrentMonth := timeutil.StartOfMonth(month).Equal(currentMonth) + if isCurrentMonth { + // Use the last time metrics were updated. If it is zero, it means the data has not + // been updated yet for the current month. + lastUpdate, err := c.GetMetricsLastUpdateTime(ctx, currentMonth) + if err != nil { + // Avoid logging raw error contents which may include sensitive information. + c.logger.Error("error retrieving last metrics update time") + return time.Time{} + } + dataUpdatedAt = lastUpdate + } else { + // Check presence of a stored metrics timestamp for the previous month. + // If present, return the canonical end-of-month for the requested + // `month`. The stored timestamp acts strictly as a + // presence indicator. + previousMonthStart := timeutil.StartOfPreviousMonth(currentMonth) + previousMonthTimestamp, err := c.GetMetricsLastUpdateTime(ctx, previousMonthStart) + + // The previous month has not been updated yet. + if err != nil || previousMonthTimestamp.IsZero() { + return time.Time{} + } + + // Use requested month's canonical end-of-month. + dataUpdatedAt = timeutil.EndOfMonth(month.UTC()) + } + + return dataUpdatedAt +} + // buildDynamicRolesMetric creates the dynamic_roles metric from role counts. 
func buildDynamicRolesMetric(counts *RoleCounts) map[string]interface{} { total := 0 @@ -342,7 +381,7 @@ func (b *SystemBackend) buildPkiBillingMetric(ctx context.Context, month time.Ti } return map[string]interface{}{ - "metric_name": pkiDurationAjustedCountMetricName, + "metric_name": "pki_units", "metric_data": map[string]interface{}{ "total": count, }, @@ -350,61 +389,38 @@ func (b *SystemBackend) buildPkiBillingMetric(ctx context.Context, month time.Ti } // getRoleCounts retrieves and combines role and managed key counts from replicated and local storage -func (c *Core) getRoleAndManagedKeyCounts(ctx context.Context, month time.Time, updateCounts bool) (*RoleCounts, *ManagedKeyCounts, error) { +func (c *Core) getRoleAndManagedKeyCounts(ctx context.Context, month time.Time) (*RoleCounts, *ManagedKeyCounts, error) { var replicatedRoleCounts *RoleCounts - var replicatedManagedKeyCounts *ManagedKeyCounts replicatedTotpHWMValue := 0 replicatedKmseHWMValue := 0 var err error if c.isPrimary() { - if updateCounts { - replicatedRoleCounts, replicatedManagedKeyCounts, err = c.UpdateMaxRoleAndManagedKeyCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error updating replicated max role and managed key counts: %w", err) - } - replicatedTotpHWMValue = replicatedManagedKeyCounts.TotpKeys - replicatedKmseHWMValue = replicatedManagedKeyCounts.KmseKeys - } else { - replicatedRoleCounts, err = c.GetStoredHWMRoleCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving replicated max role counts: %w", err) - } - replicatedTotpHWMValue, err = c.GetStoredHWMTotpCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving replicated max managed key count: %w", err) - } - replicatedKmseHWMValue, err = c.GetStoredHWMKmseCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving replicated 
max kmse key count: %w", err) - } + replicatedRoleCounts, err = c.GetStoredHWMRoleCounts(ctx, billing.ReplicatedPrefix, month) + if err != nil { + return nil, nil, fmt.Errorf("error retrieving replicated max role counts: %w", err) + } + replicatedTotpHWMValue, err = c.GetStoredHWMTotpCounts(ctx, billing.ReplicatedPrefix, month) + if err != nil { + return nil, nil, fmt.Errorf("error retrieving replicated max managed key count: %w", err) + } + replicatedKmseHWMValue, err = c.GetStoredHWMKmseCounts(ctx, billing.ReplicatedPrefix, month) + if err != nil { + return nil, nil, fmt.Errorf("error retrieving replicated max kmse key count: %w", err) } } - var localRoleCounts *RoleCounts - var localManagedKeyCounts *ManagedKeyCounts - localTotpHWMValue := 0 - localKmseHWMValue := 0 - if updateCounts { - localRoleCounts, localManagedKeyCounts, err = c.UpdateMaxRoleAndManagedKeyCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error updating local max role and managed key counts: %w", err) - } - localTotpHWMValue = localManagedKeyCounts.TotpKeys - localKmseHWMValue = localManagedKeyCounts.KmseKeys - } else { - localRoleCounts, err = c.GetStoredHWMRoleCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving local max role counts: %w", err) - } - localTotpHWMValue, err = c.GetStoredHWMTotpCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving local max totp key count: %w", err) - } - localKmseHWMValue, err = c.GetStoredHWMKmseCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return nil, nil, fmt.Errorf("error retrieving local max kmse key count: %w", err) - } + localRoleCounts, err := c.GetStoredHWMRoleCounts(ctx, billing.LocalPrefix, month) + if err != nil { + return nil, nil, fmt.Errorf("error retrieving local max role counts: %w", err) + } + localTotpHWMValue, err := c.GetStoredHWMTotpCounts(ctx, billing.LocalPrefix, month) + if err 
!= nil { + return nil, nil, fmt.Errorf("error retrieving local max totp key count: %w", err) + } + localKmseHWMValue, err := c.GetStoredHWMKmseCounts(ctx, billing.LocalPrefix, month) + if err != nil { + return nil, nil, fmt.Errorf("error retrieving local max kmse key count: %w", err) } combinedManagedKeyCounts := &ManagedKeyCounts{ @@ -416,35 +432,20 @@ func (c *Core) getRoleAndManagedKeyCounts(ctx context.Context, month time.Time, } // getKvCounts retrieves and combines KV secret counts from replicated and local storage -func (c *Core) getKvCounts(ctx context.Context, month time.Time, updateCounts bool) (int, error) { +func (c *Core) getKvCounts(ctx context.Context, month time.Time) (int, error) { var replicatedKvCounts int var err error if c.isPrimary() { - if updateCounts { - replicatedKvCounts, err = c.UpdateMaxKvCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return 0, fmt.Errorf("error updating replicated max kv counts: %w", err) - } - } else { - replicatedKvCounts, err = c.GetStoredHWMKvCounts(ctx, billing.ReplicatedPrefix, month) - if err != nil { - return 0, fmt.Errorf("error retrieving replicated max kv counts: %w", err) - } + replicatedKvCounts, err = c.GetStoredHWMKvCounts(ctx, billing.ReplicatedPrefix, month) + if err != nil { + return 0, fmt.Errorf("error retrieving replicated max kv counts: %w", err) } } - var localKvCounts int - if updateCounts { - localKvCounts, err = c.UpdateMaxKvCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return 0, fmt.Errorf("error updating local max kv counts: %w", err) - } - } else { - localKvCounts, err = c.GetStoredHWMKvCounts(ctx, billing.LocalPrefix, month) - if err != nil { - return 0, fmt.Errorf("error retrieving local max kv counts: %w", err) - } + localKvCounts, err := c.GetStoredHWMKvCounts(ctx, billing.LocalPrefix, month) + if err != nil { + return 0, fmt.Errorf("error retrieving local max kv counts: %w", err) } return replicatedKvCounts + localKvCounts, nil @@ -453,68 +454,34 @@ func 
(c *Core) getKvCounts(ctx context.Context, month time.Time, updateCounts bo // getDataProtectionCounts retrieves Transit and Transform call counts // Data protection call counts are stored at local path only // Each cluster tracks its own total requests to avoid double counting -func (c *Core) getDataProtectionCounts(ctx context.Context, month time.Time, updateCounts bool) (uint64, uint64, error) { - var transitCounts, transformCounts uint64 - var err error - - if updateCounts { - transitCounts, err = c.UpdateTransitCallCounts(ctx, month) - if err != nil { - return 0, 0, fmt.Errorf("error updating local transit call counts: %w", err) - } - transformCounts, err = c.UpdateTransformCallCounts(ctx, month) - if err != nil { - return 0, 0, fmt.Errorf("error updating local transform call counts: %w", err) - } - } else { - transitCounts, err = c.GetStoredTransitCallCounts(ctx, month) - if err != nil { - return 0, 0, fmt.Errorf("error retrieving local transit call counts: %w", err) - } - transformCounts, err = c.GetStoredTransformCallCounts(ctx, month) - if err != nil { - return 0, 0, fmt.Errorf("error retrieving local transform call counts: %w", err) - } +func (c *Core) getDataProtectionCounts(ctx context.Context, month time.Time) (uint64, uint64, error) { + transitCounts, err := c.GetStoredTransitCallCounts(ctx, month) + if err != nil { + return 0, 0, fmt.Errorf("error retrieving local transit call counts: %w", err) + } + transformCounts, err := c.GetStoredTransformCallCounts(ctx, month) + if err != nil { + return 0, 0, fmt.Errorf("error retrieving local transform call counts: %w", err) } return transitCounts, transformCounts, nil } // getKmipStatus retrieves KMIP enabled status (always stored at local path) -func (c *Core) getKmipStatus(ctx context.Context, month time.Time, updateCounts bool) (bool, error) { - var kmipEnabled bool - var err error - - if updateCounts { - kmipEnabled, err = c.UpdateKmipEnabled(ctx, month) - if err != nil { - return false, fmt.Errorf("error 
updating KMIP enabled status: %w", err) - } - } else { - kmipEnabled, err = c.GetStoredKmipEnabled(ctx, month) - if err != nil { - return false, fmt.Errorf("error retrieving KMIP enabled status: %w", err) - } +func (c *Core) getKmipStatus(ctx context.Context, month time.Time) (bool, error) { + kmipEnabled, err := c.GetStoredKmipEnabled(ctx, month) + if err != nil { + return false, fmt.Errorf("error retrieving KMIP enabled status: %w", err) } return kmipEnabled, nil } // getThirdPartyPluginCounts retrieves third-party plugin counts (always stored at local path) -func (c *Core) getThirdPartyPluginCounts(ctx context.Context, month time.Time, updateCounts bool) (int, error) { - var thirdPartyPluginCounts int - var err error - - if updateCounts { - thirdPartyPluginCounts, err = c.UpdateMaxThirdPartyPluginCounts(ctx, month) - if err != nil { - return 0, fmt.Errorf("error updating third-party plugin counts: %w", err) - } - } else { - thirdPartyPluginCounts, err = c.GetStoredThirdPartyPluginCounts(ctx, month) - if err != nil { - return 0, fmt.Errorf("error retrieving third-party plugin counts: %w", err) - } +func (c *Core) getThirdPartyPluginCounts(ctx context.Context, month time.Time) (int, error) { + thirdPartyPluginCounts, err := c.GetStoredThirdPartyPluginCounts(ctx, month) + if err != nil { + return 0, fmt.Errorf("error retrieving third-party plugin counts: %w", err) } return thirdPartyPluginCounts, nil diff --git a/vault/logical_system_use_case_billing_pki_test.go b/vault/logical_system_use_case_billing_pki_test.go index 81748ed2ea..34aa98f741 100644 --- a/vault/logical_system_use_case_billing_pki_test.go +++ b/vault/logical_system_use_case_billing_pki_test.go @@ -90,8 +90,7 @@ func TestGeneratePkiBillingMetric(t *testing.T) { overview, err := backend.buildPkiBillingMetric(ctx, month) require.NoError(t, err) - // Verify it uses the constant pkiDurationAjustedCountMetricName - require.Equal(t, pkiDurationAjustedCountMetricName, overview["metric_name"]) + // Verify it 
uses the right metric name require.Equal(t, "pki_units", overview["metric_name"]) }) } diff --git a/vault/logical_system_use_case_billing_test.go b/vault/logical_system_use_case_billing_test.go index c0bec91b18..bc5d2eceba 100644 --- a/vault/logical_system_use_case_billing_test.go +++ b/vault/logical_system_use_case_billing_test.go @@ -396,6 +396,11 @@ func TestSystemBackend_BillingOverview_PreviousMonth(t *testing.T) { c.consumptionBilling.BillingStorageLock.Unlock() require.NoError(t, err) + // Store metrics last update timestamp for previous month so it's detected as having data + testUpdateTime := time.Date(previousMonth.Year(), previousMonth.Month(), 15, 12, 0, 0, 0, time.UTC) + err = c.UpdateMetricsLastUpdateTime(ctx, previousMonth, testUpdateTime) + require.NoError(t, err) + // Make a request to the billing overview endpoint req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") resp, err := b.HandleRequest(ctx, req) @@ -422,8 +427,8 @@ func TestSystemBackend_BillingOverview_PreviousMonth(t *testing.T) { require.NoError(t, err) // The updated_at for previous month should be at the end of that month - expectedEndOfMonth := timeutil.StartOfMonth(previousMonth.AddDate(0, 1, 0)).Add(-time.Second) - require.WithinDuration(t, expectedEndOfMonth, parsedTime, time.Minute) + expectedEndOfMonth := timeutil.EndOfMonth(previousMonth).UTC() + require.Equal(t, expectedEndOfMonth, parsedTime) } // TestSystemBackend_BillingOverview_EmptyMetrics verifies that the billing overview @@ -508,12 +513,12 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) { case "pki_units": total, ok := metricData["total"].(float64) require.True(t, ok, "pki_units total should be float64") - require.Equal(t, float64(0), total, "data_protection_calls total should be 0") + require.Equal(t, float64(0), total, "pki units total should be 0") case "managed_keys": total, ok := metricData["total"].(int) require.True(t, ok, "managed_keys total should be float64") - 
require.Equal(t, int(0), total, "data_protection_calls total should be 0") + require.Equal(t, int(0), total, "managed keys total should be 0") details, ok := metricData["metric_details"].([]map[string]interface{}) require.True(t, ok, "%s metric_details should be array", metricName) require.Empty(t, details, "%s metric_details should be empty when total is 0", metricName) @@ -598,7 +603,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { c, b, _ := testCoreSystemBackend(t) ctx := namespace.RootContext(nil) - // First, call with refresh_data set to set the LastMetricsUpdate timestamp + // First, call with refresh_data set to set the metrics last update timestamp req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") req.Data["refresh_data"] = true resp, err := b.HandleRequest(ctx, req) @@ -612,18 +617,29 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { currentMonth, ok := months[0].(map[string]interface{}) require.True(t, ok) - // Get the updated_at timestamp from the first call + previousMonth, ok := months[1].(map[string]interface{}) + require.True(t, ok) + + // Get the updated_at timestamp from the first call (current month) firstUpdatedAt, ok := currentMonth["updated_at"].(string) require.True(t, ok) firstTime, err := time.Parse(time.RFC3339, firstUpdatedAt) require.NoError(t, err) - // Verify LastMetricsUpdate was set - lastUpdate := c.consumptionBilling.LastMetricsUpdate.Load() - require.NotNil(t, lastUpdate, "LastMetricsUpdate should be set after refresh") - storedTime, ok := lastUpdate.(time.Time) + // Verify the metrics last update time was set + lastUpdate, err := c.GetMetricsLastUpdateTime(ctx, time.Now().UTC()) + require.NoError(t, err) + require.Equal(t, firstTime, lastUpdate, "stored timestamp should match response timestamp") + + // Verify previous month timestamp is zero time (no data stored for previous month) + prevMonthUpdatedAt, ok := previousMonth["updated_at"].(string) 
require.True(t, ok) - require.WithinDuration(t, firstTime, storedTime, time.Second, "stored timestamp should match response timestamp") + prevMonthTime, err := time.Parse(time.RFC3339, prevMonthUpdatedAt) + require.NoError(t, err) + + // Previous month should be zero time since we haven't stored any data for it + require.True(t, prevMonthTime.IsZero(), + "previous month updated_at should be zero time when no data is stored") // Wait a moment to ensure time difference time.Sleep(100 * time.Millisecond) @@ -642,17 +658,116 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { currentMonth, ok = months[0].(map[string]interface{}) require.True(t, ok) - // Get the updated_at timestamp from the second call + previousMonth, ok = months[1].(map[string]interface{}) + require.True(t, ok) + + // Get the updated_at timestamp from the second call (current month) secondUpdatedAt, ok := currentMonth["updated_at"].(string) require.True(t, ok) secondTime, err := time.Parse(time.RFC3339, secondUpdatedAt) require.NoError(t, err) // The timestamp should be the same as the first call because we didn't refresh the data - require.WithinDuration(t, firstTime, secondTime, time.Second, - "updated_at without refresh should use stored LastMetricsUpdate timestamp") + require.Equal(t, firstTime, secondTime, + "updated_at without refresh should use stored metrics last update timestamp") // Verify the timestamps are equal require.Equal(t, firstUpdatedAt, secondUpdatedAt, "updated_at without refresh should be identical to the stored timestamp") + + // Verify previous month timestamp remains the same (zero time) + secondPrevMonthUpdatedAt, ok := previousMonth["updated_at"].(string) + require.True(t, ok) + require.Equal(t, prevMonthUpdatedAt, secondPrevMonthUpdatedAt, + "previous month updated_at should remain zero time") +} + +// TestSystemBackend_BillingOverview_UpdatedAtTimestamp_NoStoredTimestamp tests the behavior +// when the metrics last update time is zero time 
(background worker hasn't run yet) +func TestSystemBackend_BillingOverview_UpdatedAtTimestamp_NoStoredTimestamp(t *testing.T) { + c, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + // Verify the metrics last update time is zero time initially + lastUpdate, err := c.GetMetricsLastUpdateTime(ctx, time.Now().UTC()) + require.NoError(t, err) + require.True(t, lastUpdate.IsZero(), "metrics last update time should be zero time initially") + + // Call with refresh_data explicitly set to false while the stored timestamp is still zero + req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") + req.Data["refresh_data"] = false + resp, err := b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + months, ok := resp.Data["months"].([]interface{}) + require.True(t, ok) + require.Len(t, months, 2) + + currentMonth, ok := months[0].(map[string]interface{}) + require.True(t, ok) + + // Get the current month's updated_at timestamp and parse it as RFC3339 + updatedAt, ok := currentMonth["updated_at"].(string) + require.True(t, ok) + updatedTime, err := time.Parse(time.RFC3339, updatedAt) + require.NoError(t, err) + + // Verify it's zero time to indicate data hasn't been updated yet + require.True(t, updatedTime.IsZero(), + "updated_at should be zero time when the metrics last update time is zero") + + // Verify previous month is also zero time (no stored timestamp for previous month) + previousMonth, ok := months[1].(map[string]interface{}) + require.True(t, ok) + prevMonthUpdatedAt, ok := previousMonth["updated_at"].(string) + require.True(t, ok) + prevMonthTime, err := time.Parse(time.RFC3339, prevMonthUpdatedAt) + require.NoError(t, err) + + // Previous month should also be zero time since no timestamp is stored + require.True(t, prevMonthTime.IsZero(), + "previous month updated_at should be zero time when no stored timestamp exists") +} + +// TestSystemBackend_BillingOverview_PreviousMonth_WithError tests the behavior +// when retrieving the previous month's timestamp fails with an error.
+// NOTE(review): no storage error is injected here; the test only omits the stored timestamp and expects zero time. +func TestSystemBackend_BillingOverview_PreviousMonth_WithError(t *testing.T) { + c, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + // Resolve the previous month used as the storage key below (NOTE(review): local time.Now(), sibling tests use time.Now().UTC() - confirm) + previousMonth := timeutil.StartOfPreviousMonth(time.Now()) + + // Store counts but intentionally do NOT store the metrics last update timestamp + // This leaves count data present without a timestamp; no retrieval error is actually simulated + c.consumptionBilling.BillingStorageLock.Lock() + err := c.storeMaxKvCountsLocked(ctx, 5, "local/", previousMonth) + c.consumptionBilling.BillingStorageLock.Unlock() + require.NoError(t, err) + + // Make a request to the billing overview endpoint + req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") + resp, err := b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + months, ok := resp.Data["months"].([]interface{}) + require.True(t, ok) + require.Len(t, months, 2) + + // Check previous month data + previousMonthData, ok := months[1].(map[string]interface{}) + require.True(t, ok) + + // Verify updated_at is zero time when no timestamp is stored + updatedAt, ok := previousMonthData["updated_at"].(string) + require.True(t, ok) + parsedTime, err := time.Parse(time.RFC3339, updatedAt) + require.NoError(t, err) + + // Should be zero time since no timestamp was stored for previous month + require.True(t, parsedTime.IsZero(), + "previous month updated_at should be zero time when timestamp is not stored") }