From 0626b9d3696aac6548373ef1d27f96d07214de51 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Thu, 21 May 2026 12:58:23 -0600 Subject: [PATCH 1/2] VAULT-43097: Handle Phantom JWT (#14005) (#14936) * VAULT-43097 Update registerAuth to re-use preexisting enterprise token entries * VAULT-43097 Remove logging of raw error-object * VAULT-43097 Linter correction * VAULT-43097 Fix deadlock issue * VAULT-43097 Spike refactor * VAULT-43097 Sanitize register-auth error logging * VAULT-43097 Correct off-by-one timestamp in test * VAULT-43097 Add new external test * VAULT-43097 Adjust logging per review * VAULT-43097 Use type-only logging for auth registration failures * VAULT-43097 Address PR review feedback * VAULT-43097 PR review nit feedback Co-authored-by: Jason Pilz --- vault/core.go | 4 -- vault/request_handling.go | 91 +++++++++++++++++++++------------------ 2 files changed, 48 insertions(+), 47 deletions(-) diff --git a/vault/core.go b/vault/core.go index 03dc5c90d4..9eb0444f6d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -369,10 +369,6 @@ type Core struct { keepHALockOnStepDown *uint32 heldHALock physical.Lock - // enterpriseTokenGetAuthRegisterFunc is an optional per-core test seam for - // enterprise token auth registration lookup. - enterpriseTokenGetAuthRegisterFunc func(*Core) (RegisterAuthFunc, error) - // shutdownDoneCh is used to notify when core.Shutdown() completes. // core.Shutdown() is typically issued in a goroutine to allow Vault to // release the stateLock. This channel is marked atomic to prevent race diff --git a/vault/request_handling.go b/vault/request_handling.go index e5d27c9f75..28dad843e7 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -2711,6 +2711,49 @@ func (c *Core) buildMfaEnforcementResponse(eConfig *mfa.MFAEnforcementConfig, en return mfaAny, nil } +func (c *Core) registerAuthLeaseForToken(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth, role string) error { + // Populate the client token, accessor, and TTL + auth.ClientToken = te.ID + auth.Accessor = te.Accessor + auth.TTL = te.TTL + auth.Orphan = te.Parent == "" + + switch auth.TokenType { + case logical.TokenTypeBatch: + // Ensure it's not marked renewable since it isn't + auth.Renewable = false + case logical.TokenTypeService, logical.TokenTypeEnt: + if auth.TokenType == logical.TokenTypeEnt { + // Ensure it's not marked renewable since enterprise tokens are not renewable + auth.Renewable = false + } + // Register with the expiration manager + if err := c.expiration.RegisterAuth(ctx, te, auth, role); err != nil { + return err + } + if te.ExternalID != "" { + auth.ClientToken = te.ExternalID + } + // Successful login, remove any entry from userFailedLoginInfo map + // if it exists. This is done for service tokens only. + if auth.TokenType == logical.TokenTypeService && auth.Alias != nil { + loginUserInfoKey := FailedLoginUser{ + aliasName: auth.Alias.Name, + mountAccessor: auth.Alias.MountAccessor, + } + + // We don't need to try to delete the lockedUsers storage entry, since we're + // processing a login request. If a login attempt is allowed, it means the user is + // unlocked and we only add storage entry when the user gets locked. + if err := updateUserFailedLoginInfo(ctx, c, loginUserInfoKey, nil, true); err != nil { + return err + } + } + } + + return nil +} + // RegisterAuth uses a logical.Auth object to create a token entry in the token // store, and registers a corresponding token lease to the expiration manager. // role is the login role used as part of the creation of the token entry. If not @@ -2752,51 +2795,13 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st c.logger.Error("failed to create token", "error", err) return possiblyWrapOverloadedError("failed to create token", err) } - - // Populate the client token, accessor, and TTL - auth.ClientToken = te.ID - auth.Accessor = te.Accessor - auth.TTL = te.TTL - auth.Orphan = te.Parent == "" - - switch auth.TokenType { - case logical.TokenTypeBatch: - // Ensure it's not marked renewable since it isn't - auth.Renewable = false - case logical.TokenTypeService: - // Register with the expiration manager - if err := c.expiration.RegisterAuth(ctx, &te, auth, role); err != nil { - if err := c.tokenStore.revokeOrphan(ctx, te.ID); err != nil { - c.logger.Warn("failed to clean up token lease during login request", "request_path", path, "error", err) - } - c.logger.Error("failed to register token lease during login request", "request_path", path, "error", err) - return possiblyWrapOverloadedError("failed to register token lease during login request", err) + if err := c.registerAuthLeaseForToken(ctx, &te, auth, role); err != nil { + if revokeErr := c.tokenStore.revokeOrphan(ctx, te.ID); revokeErr != nil { + c.logger.Warn("failed to clean up token lease during login request", "request_path", path, "error", revokeErr) } - if te.ExternalID != "" { - auth.ClientToken = te.ExternalID - } - // Successful login, remove any entry from userFailedLoginInfo map - // if it exists. This is done for service tokens (for oss) here. - // For ent it is taken care by registerAuth RPC calls. - if auth.Alias != nil { - loginUserInfoKey := FailedLoginUser{ - aliasName: auth.Alias.Name, - mountAccessor: auth.Alias.MountAccessor, - } - - // We don't need to try to delete the lockedUsers storage entry, since we're - // processing a login request. If a login attempt is allowed, it means the user is - // unlocked and we only add storage entry when the user gets locked. - err = updateUserFailedLoginInfo(ctx, c, loginUserInfoKey, nil, true) - if err != nil { - return err - } - } - case logical.TokenTypeEnt: - // Ensure it's not marked renewable since enterprise tokens are not renewable - auth.Renewable = false + c.logger.Error("failed to register token lease during login request", "request_path", path, "error", err) + return possiblyWrapOverloadedError("failed to register token lease during login request", err) } - return nil } From be244e8702231806c7edc5592f7406a11d2e73f8 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Thu, 21 May 2026 13:14:50 -0600 Subject: [PATCH 2/2] VAULT-42829: Create a new billing/config endpoint to set retention months (#14785) (#14945) * spike * seperate overview and config into 2 paths, enable overview in admin too, restrict config to root * warn on error, do not log msg * rename default const * change storage approach to string-based pattern * rename method * fmt * linters * add all response types for the path * use get from sdk * add api client methods and tests * update unit test * add cleanup and deletion tests * fix comments, add custom retention test wwith start and end months * add namespace tests * consolidate validation into one check * fix comments * add changelog * fix test * fix subpath for config * add and use a new lock for get and update methods * add a warning for when retention is increased * Update vault/consumption_billing_test.go * remove redundant TestSystemBackend_BillingConfig_Persistence test * integrate warnings testing into another test and remove redundant test --------- Co-authored-by: Amir Aslamov Co-authored-by: Jenny Deng --- api/sys_billing.go | 68 ++++ api/sys_billing_test.go | 43 +++ changelog/_14785.txt | 3 + vault/billing/billing_counts.go | 11 +- vault/consumption_billing.go | 11 +- vault/consumption_billing_test.go | 229 +++++++++++-- vault/consumption_billing_util.go | 47 +++ vault/consumption_billing_util_test.go | 41 +++ vault/core.go | 3 + vault/external_tests/api/sys_billing_test.go | 123 ++++++- vault/logical_system_use_case_billing.go | 203 +++++++++--- vault/logical_system_use_case_billing_test.go | 305 ++++++++++++++++-- 12 files changed, 983 insertions(+), 104 deletions(-) create mode 100644 changelog/_14785.txt diff --git a/api/sys_billing.go b/api/sys_billing.go index e7c137cc80..57473d7135 100644 --- a/api/sys_billing.go +++ b/api/sys_billing.go @@ -68,3 +68,71 @@ type UsageMetric struct { MetricName string `json:"metric_name" mapstructure:"metric_name"` MetricData map[string]interface{} `json:"metric_data" mapstructure:"metric_data"` } + +// GetBillingConfig returns the current billing retention configuration. +func (c *Sys) GetBillingConfig() (*BillingConfigResponse, error) { + return c.GetBillingConfigWithContext(context.Background()) +} + +// GetBillingConfigWithContext returns the current billing retention configuration. +func (c *Sys) GetBillingConfigWithContext(ctx context.Context) (*BillingConfigResponse, error) { + ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) + defer cancelFunc() + + r := c.c.NewRequest(http.MethodGet, "/v1/sys/billing/config") + + resp, err := c.c.rawRequestWithContext(ctx, r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + secret, err := ParseSecret(resp.Body) + if err != nil { + return nil, err + } + if secret == nil || secret.Data == nil { + return nil, errors.New("data from server response is empty") + } + + var result BillingConfigResponse + err = mapstructure.Decode(secret.Data, &result) + if err != nil { + return nil, err + } + + return &result, nil +} + +// SetBillingConfig sets the billing retention configuration. +func (c *Sys) SetBillingConfig(retentionMonths int) error { + return c.SetBillingConfigWithContext(context.Background(), retentionMonths) +} + +// SetBillingConfigWithContext sets the billing retention configuration. +func (c *Sys) SetBillingConfigWithContext(ctx context.Context, retentionMonths int) error { + ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) + defer cancelFunc() + + body := map[string]interface{}{ + "retention_months": retentionMonths, + } + + r := c.c.NewRequest(http.MethodPost, "/v1/sys/billing/config") + if err := r.SetJSONBody(body); err != nil { + return err + } + + resp, err := c.c.rawRequestWithContext(ctx, r) + if err != nil { + return err + } + defer resp.Body.Close() + + return nil +} + +// BillingConfigResponse represents the response from the billing config endpoint. +type BillingConfigResponse struct { + RetentionMonths int `json:"retention_months" mapstructure:"retention_months"` +} diff --git a/api/sys_billing_test.go b/api/sys_billing_test.go index 615ccd237f..660769e898 100644 --- a/api/sys_billing_test.go +++ b/api/sys_billing_test.go @@ -250,3 +250,46 @@ const billingOverviewResponse = `{ "warnings": null, "auth": null }` + +// TestSys_BillingConfig tests the GetBillingConfig and SetBillingConfig API client methods +func TestSys_BillingConfig(t *testing.T) { + mockVaultServer := httptest.NewServer(http.HandlerFunc(mockVaultBillingConfigHandler)) + defer mockVaultServer.Close() + + // Create API client pointing to mock server + cfg := DefaultConfig() + cfg.Address = mockVaultServer.URL + client, err := NewClient(cfg) + require.NoError(t, err) + + // Test GetBillingConfig + resp, err := client.Sys().GetBillingConfig() + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 37, resp.RetentionMonths) + + // Test SetBillingConfig + err = client.Sys().SetBillingConfig(48) + require.NoError(t, err) +} + +func mockVaultBillingConfigHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = w.Write([]byte(billingConfigResponse)) + } else if r.Method == http.MethodPost { + w.WriteHeader(http.StatusNoContent) + } +} + +const billingConfigResponse = `{ + "request_id": "a1b2c3d4-e5f6-7a8b-9c0d-1e2f3a4b5c6d", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "retention_months": 37 + }, + "wrap_info": null, + "warnings": null, + "auth": null +}` diff --git a/changelog/_14785.txt b/changelog/_14785.txt new file mode 100644 index 0000000000..3b847da5c0 --- /dev/null +++ b/changelog/_14785.txt @@ -0,0 +1,3 @@ +```release-note:improvement +consumption-billing: Add a new `sys/billing/config` endpoint to allow configuration of billing data retention (min 13 months, max 6 years). +``` \ No newline at end of file diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index 6721cc505b..09312e49ac 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -17,11 +17,18 @@ import ( ) const ( - // BillingRetentionMonths is the number of months of billing data to retain. + // DefaultBillingRetentionMonths is the default number of months of billing data to retain. // This includes the current month plus previous months (e.g., 37 = current + 36 previous months). - BillingRetentionMonths = 37 + DefaultBillingRetentionMonths = 37 + + // MinBillingRetentionMonths is the minimum allowed retention period (13 months = 1 year + current month) + MinBillingRetentionMonths = 13 + + // MaxBillingRetentionMonths is the maximum allowed retention period (72 months = 6 years) + MaxBillingRetentionMonths = 72 BillingSubPath = "billing/" + BillingConfigPath = "config" ReplicatedPrefix = "replicated/" RoleHWMCountsHWM = "maxRoleCounts/" TotpHWMCountsHWM = "maxTotpCounts/" diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index 439381b51d..92e220348d 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -131,8 +131,15 @@ func (c *Core) HandleStartOfMonth(ctx context.Context, currentMonth time.Time) { } func (c *Core) deleteExpiredBillingMetrics(ctx context.Context, currentMonth time.Time) error { - // Delete data from BillingRetentionMonths ago (keeping current month + previous (BillingRetentionMonths - 1) months = BillingRetentionMonths total) - monthToDelete := timeutil.StartOfMonth(currentMonth).AddDate(0, -billing.BillingRetentionMonths, 0) + // Get the configured retention period + retentionMonths, err := c.GetBillingRetentionMonths(ctx) + if err != nil { + c.logger.Warn("failed to get billing retention configuration, using default") + retentionMonths = billing.DefaultBillingRetentionMonths + } + + // Delete data from retentionMonths ago (keeping current month + previous (retentionMonths - 1) months = retentionMonths total) + monthToDelete := timeutil.StartOfMonth(currentMonth).AddDate(0, -retentionMonths, 0) // Delete billing metrics from both replicated and local prefixes for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { // If we are not the primary, then do not delete replicate metrics diff --git a/vault/consumption_billing_test.go b/vault/consumption_billing_test.go index d469659b96..773d2a8eb2 100644 --- a/vault/consumption_billing_test.go +++ b/vault/consumption_billing_test.go @@ -154,7 +154,7 @@ func TestConsumptionBillingMetricsWorker(t *testing.T) { } // TestHandleEndOfMonthMetrics tests that HandleEndOfMonth cleans up -// billing metrics from billing.BillingRetentionMonths ago (keeping billing.BillingRetentionMonths of data) and resets the in memory billing metrics +// billing metrics from billing.DefaultBillingRetentionMonths ago (keeping billing.DefaultBillingRetentionMonths of data) and resets the in memory billing metrics func TestHandleEndOfMonthMetrics(t *testing.T) { coreConfig := &CoreConfig{ LogicalBackends: roleLogicalBackends, @@ -163,11 +163,11 @@ func TestHandleEndOfMonthMetrics(t *testing.T) { }, } core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) - // Add some billing metrics to storage for (billing.BillingRetentionMonths - 1) and billing.BillingRetentionMonths months ago + // Add some billing metrics to storage for (billing.DefaultBillingRetentionMonths - 1) and billing.DefaultBillingRetentionMonths months ago // Use the util functions directly to avoid the need to mount the logical backends now := time.Now().UTC() - oldestRetainedMonth := timeutil.StartOfMonth(now).AddDate(0, -(billing.BillingRetentionMonths - 1), 0) - monthToDelete := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths, 0) + oldestRetainedMonth := timeutil.StartOfMonth(now).AddDate(0, -(billing.DefaultBillingRetentionMonths - 1), 0) + monthToDelete := timeutil.StartOfMonth(now).AddDate(0, -billing.DefaultBillingRetentionMonths, 0) for _, month := range []time.Time{monthToDelete, oldestRetainedMonth} { for _, localPathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { @@ -209,14 +209,14 @@ func TestHandleEndOfMonthMetrics(t *testing.T) { core.HandleStartOfMonth(context.Background(), now) for _, localPathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { - // billing.BillingRetentionMonths ago should have no billing metrics (deleted) + // billing.DefaultBillingRetentionMonths ago should have no billing metrics (deleted) view, ok := core.GetBillingSubView() require.True(t, ok) paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(localPathPrefix, monthToDelete)) require.NoError(t, err) - require.Equal(t, 0, len(paths), "data from billing.BillingRetentionMonths ago should be deleted") + require.Equal(t, 0, len(paths), "data from billing.DefaultBillingRetentionMonths ago should be deleted") - // (billing.BillingRetentionMonths - 1) months ago should still have the billing metrics (kept) + // (billing.DefaultBillingRetentionMonths - 1) months ago should still have the billing metrics (kept) view, ok = core.GetBillingSubView() require.True(t, ok) paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(localPathPrefix, oldestRetainedMonth)) @@ -236,8 +236,8 @@ func TestHandleEndOfMonthMetrics(t *testing.T) { } // TestDeleteExpiredBillingMetrics specifically tests the deleteExpiredBillingMetrics method -// to ensure it correctly deletes data from billing.BillingRetentionMonths ago while keeping -// data from (billing.BillingRetentionMonths - 1) months ago. +// to ensure it correctly deletes data from billing.DefaultBillingRetentionMonths ago while keeping +// data from (billing.DefaultBillingRetentionMonths - 1) months ago. func TestDeleteExpiredBillingMetrics(t *testing.T) { coreConfig := &CoreConfig{ LogicalBackends: roleLogicalBackends, @@ -246,8 +246,8 @@ func TestDeleteExpiredBillingMetrics(t *testing.T) { now := time.Now().UTC() currentMonth := timeutil.StartOfMonth(now) - oldestRetainedMonth := currentMonth.AddDate(0, -(billing.BillingRetentionMonths - 1), 0) - monthToDelete := currentMonth.AddDate(0, -billing.BillingRetentionMonths, 0) + oldestRetainedMonth := currentMonth.AddDate(0, -(billing.DefaultBillingRetentionMonths - 1), 0) + monthToDelete := currentMonth.AddDate(0, -billing.DefaultBillingRetentionMonths, 0) // Write billing data for multiple months including the month to be deleted and the oldest retained month for _, month := range []time.Time{monthToDelete, oldestRetainedMonth, currentMonth} { @@ -316,7 +316,7 @@ func TestDeleteExpiredBillingMetrics(t *testing.T) { // Month to delete should have no data paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, monthToDelete)) require.NoError(t, err) - require.Equal(t, 0, len(paths), "data from billing.BillingRetentionMonths ago should be deleted") + require.Equal(t, 0, len(paths), "data from billing.DefaultBillingRetentionMonths ago should be deleted") // Verify SSH metrics are deleted (they use subdirectory paths) sshCertPath := billing.GetMonthlyBillingMetricPath(pathPrefix, monthToDelete, billing.SSHCertificateMetric) @@ -332,7 +332,7 @@ func TestDeleteExpiredBillingMetrics(t *testing.T) { // Oldest retained month should still have data paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, oldestRetainedMonth)) require.NoError(t, err) - require.Greater(t, len(paths), 0, "data from (billing.BillingRetentionMonths - 1) months ago should be kept") + require.Greater(t, len(paths), 0, "data from (billing.DefaultBillingRetentionMonths - 1) months ago should be kept") // Verify SSH metrics are kept for oldest retained month sshCertPath = billing.GetMonthlyBillingMetricPath(pathPrefix, oldestRetainedMonth, billing.SSHCertificateMetric) @@ -368,7 +368,7 @@ func TestDeleteExpiredBillingMetrics(t *testing.T) { require.False(t, currentTimestamp.IsZero(), "timestamp for current month should exist") } -// TestConsumptionBillingMetricsWorkerWithCustomClock tests that we correctly delete data older than billing.BillingRetentionMonths +// TestConsumptionBillingMetricsWorkerWithCustomClock tests that we correctly delete data older than billing.DefaultBillingRetentionMonths // and reset the in memory billing metrics when the clock is overridden for testing purposes func TestConsumptionBillingMetricsWorkerWithCustomClock(t *testing.T) { // 10 seconds until a new month (leave buffer for require.Eventually timeout) @@ -381,14 +381,14 @@ func TestConsumptionBillingMetricsWorkerWithCustomClock(t *testing.T) { } core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) - // Add some billing metrics to storage for (billing.BillingRetentionMonths - 1) and billing.BillingRetentionMonths months ago + // Add some billing metrics to storage for (billing.DefaultBillingRetentionMonths - 1) and billing.DefaultBillingRetentionMonths months ago // Use the util functions directly to avoid the need to mount the logical backends // The worker's "end of month" path calls HandleEndOfMonth with the *current* month, // which will be the next month once we cross the boundary. So the months should be // calculated relative to that boundary. currentMonthAtBoundary := timeutil.StartOfNextMonth(now) - oldestRetainedMonth := timeutil.StartOfMonth(currentMonthAtBoundary).AddDate(0, -(billing.BillingRetentionMonths - 1), 0) - monthToDelete := timeutil.StartOfMonth(currentMonthAtBoundary).AddDate(0, -billing.BillingRetentionMonths, 0) + oldestRetainedMonth := timeutil.StartOfMonth(currentMonthAtBoundary).AddDate(0, -(billing.DefaultBillingRetentionMonths - 1), 0) + monthToDelete := timeutil.StartOfMonth(currentMonthAtBoundary).AddDate(0, -billing.DefaultBillingRetentionMonths, 0) view, ok := core.GetBillingSubView() require.True(t, ok) roleCounts := &RoleCounts{ @@ -448,13 +448,13 @@ func TestConsumptionBillingMetricsWorkerWithCustomClock(t *testing.T) { } for _, localPathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { - // billing.BillingRetentionMonths ago should eventually have no billing metrics (deleted) + // billing.DefaultBillingRetentionMonths ago should eventually have no billing metrics (deleted) require.Eventually(t, func() bool { paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(localPathPrefix, monthToDelete)) return err == nil && len(paths) == 0 }, 20*time.Second, 100*time.Millisecond) - // All values from billing.BillingRetentionMonths ago should be 0 + // All values from billing.DefaultBillingRetentionMonths ago should be 0 maxRoleCounts, _ := core.GetStoredHWMRoleCounts(context.Background(), localPathPrefix, monthToDelete) require.Equal(t, &RoleCounts{}, maxRoleCounts) kvCounts, _ := core.GetStoredHWMKvCounts(context.Background(), localPathPrefix, monthToDelete) @@ -468,7 +468,7 @@ func TestConsumptionBillingMetricsWorkerWithCustomClock(t *testing.T) { require.Equal(t, 0, thirdPartyPluginCounts) } - // (billing.BillingRetentionMonths - 1) months ago should still have the billing metrics (kept) + // (billing.DefaultBillingRetentionMonths - 1) months ago should still have the billing metrics (kept) verifyMonthlyBillingMetrics(oldestRetainedMonth, localPathPrefix) } @@ -477,3 +477,192 @@ func TestConsumptionBillingMetricsWorkerWithCustomClock(t *testing.T) { require.Equal(t, uint64(0), core.GetInMemoryGcpKmsDataProtectionCallCounts()) require.False(t, core.consumptionBilling.KmipSeenEnabledThisMonth.Load()) } + +// TestDeleteExpiredBillingMetrics_CustomRetention tests that deleteExpiredBillingMetrics +// respects custom retention configuration. It verifies that when a custom retention period +// is set (e.g., 13 months), data is deleted according to that configuration rather than +// the default 37 months. +func TestDeleteExpiredBillingMetrics_CustomRetention(t *testing.T) { + coreConfig := &CoreConfig{ + LogicalBackends: roleLogicalBackends, + } + core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) + ctx := namespace.RootContext(context.Background()) + + // Set custom retention to minimum (13 months) + customRetention := billing.MinBillingRetentionMonths + err := core.UpdateBillingRetentionMonths(ctx, customRetention) + require.NoError(t, err) + + // Verify the custom retention was set + retentionMonths, err := core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, customRetention, retentionMonths) + + now := time.Now().UTC() + currentMonth := timeutil.StartOfMonth(now) + + // With 13 months retention: + // - Month 12 months ago (index 12) should be kept (oldest retained) + // - Month 13 months ago (index 13) should be deleted + oldestRetainedMonth := currentMonth.AddDate(0, -(customRetention - 1), 0) + monthToDelete := currentMonth.AddDate(0, -customRetention, 0) + + // Write billing data for multiple months + for _, month := range []time.Time{monthToDelete, oldestRetainedMonth, currentMonth} { + for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { + core.storeMaxRoleCountsLocked(context.Background(), &RoleCounts{ + AWSDynamicRoles: 5, + AWSStaticRoles: 10, + LDAPDynamicRoles: 3, + OSLocalAccountRoles: 7, + }, pathPrefix, month) + core.storeMaxKvCountsLocked(context.Background(), 20, pathPrefix, month) + core.storeTransitCallCountsLocked(context.Background(), 15, pathPrefix, month) + core.storeSSHDurationAdjustedCertCountLocked(context.Background(), pathPrefix, month, 10.5) + core.storeSSHOTPCountLocked(context.Background(), pathPrefix, month, 25.0) + } + // Store updatedAtTimestamp for each month + testUpdateTime := time.Date(month.Year(), month.Month(), 15, 12, 0, 0, 0, time.UTC) + err := core.UpdateMetricsLastUpdateTime(context.Background(), month, testUpdateTime) + require.NoError(t, err) + } + + // Verify data exists before deletion + for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { + view, ok := core.GetBillingSubView() + require.True(t, ok) + + // Check month to be deleted has data + paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, monthToDelete)) + require.NoError(t, err) + require.Greater(t, len(paths), 0, "month to delete should have data before deletion") + + // Check oldest retained month has data + paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, oldestRetainedMonth)) + require.NoError(t, err) + require.Greater(t, len(paths), 0, "oldest retained month should have data") + } + + // Call deleteExpiredBillingMetrics - it should use the custom retention + err = core.deleteExpiredBillingMetrics(context.Background(), currentMonth) + require.NoError(t, err) + + // Verify deletion results with custom retention + for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { + view, ok := core.GetBillingSubView() + require.True(t, ok) + + // Month to delete (13 months ago with custom retention) should have no data + paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, monthToDelete)) + require.NoError(t, err) + require.Equal(t, 0, len(paths), "data from %d months ago should be deleted with custom retention", customRetention) + + // Oldest retained month (12 months ago) should still have data + paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, oldestRetainedMonth)) + require.NoError(t, err) + require.Greater(t, len(paths), 0, "data from %d months ago should be kept with custom retention", customRetention-1) + + // Current month should still have data + paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, currentMonth)) + require.NoError(t, err) + require.Greater(t, len(paths), 0, "current month data should be kept") + } + + // Verify updatedAtTimestamp deletion with custom retention + deletedTimestamp, err := core.GetMetricsLastUpdateTime(context.Background(), monthToDelete) + require.NoError(t, err) + require.True(t, deletedTimestamp.IsZero(), "timestamp for deleted month should be zero") + + oldestTimestamp, err := core.GetMetricsLastUpdateTime(context.Background(), oldestRetainedMonth) + require.NoError(t, err) + require.False(t, oldestTimestamp.IsZero(), "timestamp for oldest retained month should exist") + + currentTimestamp, err := core.GetMetricsLastUpdateTime(context.Background(), currentMonth) + require.NoError(t, err) + require.False(t, currentTimestamp.IsZero(), "timestamp for current month should exist") +} + +// TestHandleStartOfMonth_CustomRetention tests that HandleStartOfMonth respects +// custom retention configuration when deleting expired billing metrics. +func TestHandleStartOfMonth_CustomRetention(t *testing.T) { + coreConfig := &CoreConfig{ + LogicalBackends: roleLogicalBackends, + BillingConfig: billing.BillingConfig{ + MetricsUpdateCadence: 3 * time.Second, + }, + } + core, _, _ := TestCoreUnsealedWithConfig(t, coreConfig) + ctx := namespace.RootContext(context.Background()) + + // Set custom retention to 20 months + customRetention := 20 + err := core.UpdateBillingRetentionMonths(ctx, customRetention) + require.NoError(t, err) + + // Add billing metrics for months based on custom retention + now := time.Now().UTC() + oldestRetainedMonth := timeutil.StartOfMonth(now).AddDate(0, -(customRetention - 1), 0) + monthToDelete := timeutil.StartOfMonth(now).AddDate(0, -customRetention, 0) + + for _, month := range []time.Time{monthToDelete, oldestRetainedMonth} { + for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { + core.storeMaxRoleCountsLocked(context.Background(), &RoleCounts{ + AWSDynamicRoles: 10, + AWSStaticRoles: 15, + LDAPDynamicRoles: 8, + GCPRolesets: 3, + DatabaseDynamicRoles: 5, + DatabaseStaticRoles: 7, + OSLocalAccountRoles: 9, + }, pathPrefix, month) + core.storeMaxKvCountsLocked(context.Background(), 10, pathPrefix, month) + + if pathPrefix == billing.LocalPrefix { + core.storeTransitCallCountsLocked(context.Background(), 10, pathPrefix, month) + core.storeGcpKmsCallCountsLocked(context.Background(), 10, pathPrefix, month) + core.storeThirdPartyPluginCountsLocked(context.Background(), pathPrefix, month, 10) + core.storeOidcDurationAdjustedCountLocked(context.Background(), month, 10) + core.storeSSHOTPCountLocked(context.Background(), pathPrefix, month, 10) + } + + // Verify data was stored + view, ok := core.GetBillingSubView() + require.True(t, ok) + paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, month)) + require.NoError(t, err) + expectedPaths := 2 // ReplicatedPrefix has roles and kv + if pathPrefix == billing.LocalPrefix { + expectedPaths = 7 // LocalPrefix has roles, kv, transit, gcp kms, third-party plugins, ssh and OIDC + } + require.Equal(t, expectedPaths, len(paths)) + } + } + + // Handle the start of the month - should delete based on custom retention + core.HandleStartOfMonth(context.Background(), now) + + for _, pathPrefix := range []string{billing.ReplicatedPrefix, billing.LocalPrefix} { + // Month to delete (customRetention months ago) should have no billing metrics + view, ok := core.GetBillingSubView() + require.True(t, ok) + paths, err := view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, monthToDelete)) + require.NoError(t, err) + require.Equal(t, 0, len(paths), "data from %d months ago should be deleted with custom retention", customRetention) + + // Oldest retained month should still have the billing metrics + paths, err = view.List(context.Background(), billing.GetMonthlyBillingPath(pathPrefix, oldestRetainedMonth)) + require.NoError(t, err) + expectedPaths := 2 // ReplicatedPrefix has roles and kv + if pathPrefix == billing.LocalPrefix { + expectedPaths = 7 // LocalPrefix has roles, kv, transit, gcp kms, third-party plugins, ssh and OIDC + } + require.Equal(t, expectedPaths, len(paths), "data from %d months ago should be kept with custom retention", customRetention-1) + } + + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts()) + require.Equal(t, uint64(0), core.GetInMemoryTransformDataProtectionCallCounts()) + require.Equal(t, uint64(0), core.GetInMemoryGcpKmsDataProtectionCallCounts()) + require.Equal(t, float64(0), core.GetInMemoryOidcCounts()) + require.False(t, core.consumptionBilling.KmipSeenEnabledThisMonth.Load()) +} diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index 9b4a5ab5fc..feb15c3836 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -456,6 +456,53 @@ func (c *Core) GetBillingSubView() (*BarrierView, bool) { return c.consumptionBillingSubView, true } +func (c *Core) GetBillingRetentionMonths(ctx context.Context) (int, error) { + c.billingConfigLock.RLock() + defer c.billingConfigLock.RUnlock() + + view, ok := c.GetBillingSubView() + if !ok { + return billing.DefaultBillingRetentionMonths, nil + } + + entry, err := view.Get(ctx, billing.BillingConfigPath) + if err != nil { + return 0, fmt.Errorf("failed to read billing config: %w", err) + } + if entry == nil { + // No config stored, return default + return billing.DefaultBillingRetentionMonths, nil + } + + retentionMonths, err := strconv.Atoi(string(entry.Value)) + if err != nil { + return 0, err + } + + return retentionMonths, nil +} + +func (c *Core) UpdateBillingRetentionMonths(ctx context.Context, retentionMonths int) error { + c.billingConfigLock.Lock() + defer c.billingConfigLock.Unlock() + + view, ok := c.GetBillingSubView() + if !ok { + return fmt.Errorf("billing sub view not available") + } + + entry := &logical.StorageEntry{ + Key: billing.BillingConfigPath, + Value: []byte(strconv.Itoa(retentionMonths)), + } + + if err := view.Put(ctx, entry); err != nil { + return fmt.Errorf("failed to store billing config: %w", err) + } + + return nil +} + // storeTransitCallCountsLocked must be called with BillingStorageLock held func (c *Core) storeTransitCallCountsLocked(ctx context.Context, transitCount uint64, localPathPrefix string, month time.Time) error { // Store count for each data protection type separately because they are atomic counters diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index 740ff141ae..79c2987b4a 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -1617,3 +1617,44 @@ func TestGcpKmsDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(3), counts) } + +// TestCore_BillingRetentionMonths tests the GetBillingRetentionMonths and UpdateBillingRetentionMonths methods. +func TestCore_BillingRetentionMonths(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(context.Background()) + + // When no configuration is stored, should return default value + retentionMonths, err := core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, billing.DefaultBillingRetentionMonths, retentionMonths) + + // Update to minimum value and verify + err = core.UpdateBillingRetentionMonths(ctx, billing.MinBillingRetentionMonths) + require.NoError(t, err) + retentionMonths, err = core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, billing.MinBillingRetentionMonths, retentionMonths) + + // Update to maximum value and verify + err = core.UpdateBillingRetentionMonths(ctx, billing.MaxBillingRetentionMonths) + require.NoError(t, err) + retentionMonths, err = core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, billing.MaxBillingRetentionMonths, retentionMonths) + + // Update to custom value and verify persistence + customRetention := 48 + err = core.UpdateBillingRetentionMonths(ctx, customRetention) + require.NoError(t, err) + retentionMonths, err = core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, customRetention, retentionMonths) + + // Update to a different custom value and verify persistence + newRetention := 60 + err = core.UpdateBillingRetentionMonths(ctx, newRetention) + require.NoError(t, err) + retentionMonths, err = core.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, newRetention, retentionMonths) +} diff --git a/vault/core.go b/vault/core.go index 9eb0444f6d..72114ba948 100644 --- a/vault/core.go +++ b/vault/core.go @@ -467,6 +467,9 @@ type Core struct { // consumptionBillingLock protects the consumptionBilling struct consumptionBillingLock sync.RWMutex + // billingConfigLock protects billing configuration reads and writes + billingConfigLock sync.RWMutex + // consumptionBillingSubView is the sub-view of the system barrier view that is used to store consumption billing metrics consumptionBillingSubView *BarrierView diff --git a/vault/external_tests/api/sys_billing_test.go b/vault/external_tests/api/sys_billing_test.go index 18d46db70b..bf9d8ed6c5 100644 --- a/vault/external_tests/api/sys_billing_test.go +++ b/vault/external_tests/api/sys_billing_test.go @@ -83,7 +83,7 @@ func Test_BillingOverview(t *testing.T) { // Validate response structure require.NotNil(t, resp.Months) - require.Len(t, resp.Months, billing.BillingRetentionMonths, "should have billing.BillingRetentionMonths months") + require.Len(t, resp.Months, billing.DefaultBillingRetentionMonths, "should have billing.DefaultBillingRetentionMonths months") // Check current month data currentMonth := resp.Months[0] @@ -132,7 +132,7 @@ func Test_BillingOverview_WithoutUpdateCounts(t *testing.T) { // Validate basic response structure require.NotNil(t, resp.Months) - require.Len(t, resp.Months, billing.BillingRetentionMonths, "should have billing.BillingRetentionMonths months") + require.Len(t, resp.Months, billing.DefaultBillingRetentionMonths, "should have billing.DefaultBillingRetentionMonths months") // Check that months are properly formatted for _, month := range resp.Months { @@ -155,7 +155,7 @@ func Test_BillingOverview_EmptyCluster(t *testing.T) { require.NotNil(t, resp) require.NotNil(t, resp.Months) - require.Len(t, resp.Months, billing.BillingRetentionMonths) + require.Len(t, resp.Months, billing.DefaultBillingRetentionMonths) currentMonth := resp.Months[0] require.NotEmpty(t, currentMonth.Month) @@ -211,16 +211,121 @@ func Test_BillingOverview_MonthFormat(t *testing.T) { // Verify months are in descending order (current, then previous months) require.Greater(t, resp.Months[0].Month, resp.Months[1].Month, "first month should be more recent than second") - // Verify we have billing.BillingRetentionMonths months - require.Len(t, resp.Months, billing.BillingRetentionMonths) + // Verify we have billing.DefaultBillingRetentionMonths months + require.Len(t, resp.Months, billing.DefaultBillingRetentionMonths) - // Verify the oldest month is exactly (billing.BillingRetentionMonths - 1) months before the current month + // Verify the oldest month is exactly (billing.DefaultBillingRetentionMonths - 1) months before the current month currentMonthTime, err := time.Parse("2006-01", resp.Months[0].Month) require.NoError(t, err, "should parse current month") - oldestMonthTime, err := time.Parse("2006-01", resp.Months[billing.BillingRetentionMonths-1].Month) + oldestMonthTime, err := time.Parse("2006-01", resp.Months[billing.DefaultBillingRetentionMonths-1].Month) require.NoError(t, err, "should parse oldest month") - expectedOldestMonth := currentMonthTime.AddDate(0, -(billing.BillingRetentionMonths - 1), 0) + expectedOldestMonth := currentMonthTime.AddDate(0, -(billing.DefaultBillingRetentionMonths - 1), 0) require.Equal(t, expectedOldestMonth.Format("2006-01"), oldestMonthTime.Format("2006-01"), - "oldest month should be exactly %d months before current month", billing.BillingRetentionMonths-1) + "oldest month should be exactly %d months before current month", billing.DefaultBillingRetentionMonths-1) +} + +// Test_BillingConfig tests the GetBillingConfig and SetBillingConfig API methods +func Test_BillingConfig(t *testing.T) { + t.Parallel() + + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + // Test GetBillingConfig - should return default value + config, err := client.Sys().GetBillingConfig() + require.NoError(t, err) + require.NotNil(t, config) + require.Equal(t, billing.DefaultBillingRetentionMonths, config.RetentionMonths) + + // Test SetBillingConfig with valid value + err = client.Sys().SetBillingConfig(24) + require.NoError(t, err) + + // Verify the value was updated + config, err = client.Sys().GetBillingConfig() + require.NoError(t, err) + require.Equal(t, 24, config.RetentionMonths) + + // Test SetBillingConfig with minimum value + err = client.Sys().SetBillingConfig(billing.MinBillingRetentionMonths) + require.NoError(t, err) + + config, err = client.Sys().GetBillingConfig() + require.NoError(t, err) + require.Equal(t, billing.MinBillingRetentionMonths, config.RetentionMonths) + + // Test SetBillingConfig with maximum value + err = client.Sys().SetBillingConfig(billing.MaxBillingRetentionMonths) + require.NoError(t, err) + + config, err = client.Sys().GetBillingConfig() + require.NoError(t, err) + require.Equal(t, billing.MaxBillingRetentionMonths, config.RetentionMonths) +} + +// Test_BillingConfig_InvalidValues tests that invalid retention values are rejected +func Test_BillingConfig_InvalidValues(t *testing.T) { + t.Parallel() + + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + // Test below minimum + err := client.Sys().SetBillingConfig(billing.MinBillingRetentionMonths - 1) + require.Error(t, err) + require.Contains(t, err.Error(), "must be between") + + // Test above maximum + err = client.Sys().SetBillingConfig(billing.MaxBillingRetentionMonths + 1) + require.Error(t, err) + require.Contains(t, err.Error(), "must be between") + + // Test zero + err = client.Sys().SetBillingConfig(0) + require.Error(t, err) + require.Contains(t, err.Error(), "must be between") + + // Test negative + err = client.Sys().SetBillingConfig(-1) + require.Error(t, err) + require.Contains(t, err.Error(), "must be between") +} + +// Test_BillingConfig_AffectsOverview tests that config changes affect billing overview +func Test_BillingConfig_AffectsOverview(t *testing.T) { + t.Parallel() + + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + // Set retention to minimum + err := client.Sys().SetBillingConfig(billing.MinBillingRetentionMonths) + require.NoError(t, err) + + // Get billing overview + resp, err := client.Sys().BillingOverview(false) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Months, billing.MinBillingRetentionMonths) + + // Set retention to maximum + err = client.Sys().SetBillingConfig(billing.MaxBillingRetentionMonths) + require.NoError(t, err) + + // Get billing overview again + resp, err = client.Sys().BillingOverview(false) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Months, billing.MaxBillingRetentionMonths) + + // Set back to default + err = client.Sys().SetBillingConfig(billing.DefaultBillingRetentionMonths) + require.NoError(t, err) + + // Verify default is restored + resp, err = client.Sys().BillingOverview(false) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Months, billing.DefaultBillingRetentionMonths) } diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index e3d4e9208c..bf7b7d8f71 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -27,56 +27,159 @@ const ( func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path { return []*framework.Path{ - { - Pattern: "billing/overview$", - Fields: map[string]*framework.FieldSchema{ - "refresh_data": { - Type: framework.TypeBool, - Description: "If set, updates the billing counts for the current month before returning. This is an expensive operation with potential performance impact and should be used sparingly.", - Query: true, - }, - "start_month": { - Type: framework.TypeString, - Description: "Start month in YYYY-MM format (inclusive). If not specified, defaults to the oldest available month within BillingRetentionMonths.", - Query: true, - }, - "end_month": { - Type: framework.TypeString, - Description: "End month in YYYY-MM format (inclusive). If not specified, defaults to the current month.", - Query: true, - }, + b.billingOverviewPath(), + b.billingConfigPath(), + } +} + +func (b *SystemBackend) billingOverviewPath() *framework.Path { + return &framework.Path{ + Pattern: "billing/overview$", + Fields: map[string]*framework.FieldSchema{ + "refresh_data": { + Type: framework.TypeBool, + Description: "If set, updates the billing counts for the current month before returning. This is an expensive operation with potential performance impact and should be used sparingly.", + Query: true, }, - Operations: map[logical.Operation]framework.OperationHandler{ - logical.ReadOperation: &framework.PathOperation{ - Callback: b.handleUseCaseConsumption, - Summary: "Reports consumption billing metrics on a monthly granularity.", - Responses: map[int][]framework.Response{ - http.StatusOK: {{ - Description: http.StatusText(http.StatusOK), - Fields: map[string]*framework.FieldSchema{ - "months": { - Type: framework.TypeSlice, - Description: "List of monthly billing data.", - }, + "start_month": { + Type: framework.TypeString, + Description: "Start month in YYYY-MM format (inclusive). If not specified, defaults to the oldest available month within BillingRetentionMonths.", + Query: true, + }, + "end_month": { + Type: framework.TypeString, + Description: "End month in YYYY-MM format (inclusive). If not specified, defaults to the current month.", + Query: true, + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleBillingOverview, + Summary: "Reports consumption billing metrics on a monthly granularity.", + Responses: map[int][]framework.Response{ + http.StatusOK: {{ + Description: http.StatusText(http.StatusOK), + Fields: map[string]*framework.FieldSchema{ + "months": { + Type: framework.TypeSlice, + Description: "List of monthly billing data.", }, - }}, - http.StatusNoContent: {{ - Description: http.StatusText(http.StatusNoContent), - }}, - http.StatusBadRequest: {{ - Description: http.StatusText(http.StatusBadRequest), - }}, - http.StatusInternalServerError: {{ - Description: http.StatusText(http.StatusInternalServerError), - }}, - }, + }, + }}, + http.StatusNoContent: {{ + Description: http.StatusText(http.StatusNoContent), + }}, + http.StatusBadRequest: {{ + Description: http.StatusText(http.StatusBadRequest), + }}, + http.StatusInternalServerError: {{ + Description: http.StatusText(http.StatusInternalServerError), + }}, }, }, }, } } -func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { +func (b *SystemBackend) billingConfigPath() *framework.Path { + return &framework.Path{ + Pattern: "billing/config$", + Fields: map[string]*framework.FieldSchema{ + "retention_months": { + Type: framework.TypeInt, + Description: fmt.Sprintf("Number of months to retain billing data. Must be between %d and %d months. Defaults to %d months.", billing.MinBillingRetentionMonths, billing.MaxBillingRetentionMonths, billing.DefaultBillingRetentionMonths), + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleBillingConfigRead, + Summary: "Read the billing data retention configuration.", + Responses: map[int][]framework.Response{ + http.StatusOK: {{ + Description: http.StatusText(http.StatusOK), + Fields: map[string]*framework.FieldSchema{ + "retention_months": { + Type: framework.TypeInt, + Description: "Number of months of billing data to retain.", + }, + }, + }}, + http.StatusNoContent: {{ + Description: http.StatusText(http.StatusNoContent), + }}, + http.StatusBadRequest: {{ + Description: http.StatusText(http.StatusBadRequest), + }}, + http.StatusInternalServerError: {{ + Description: http.StatusText(http.StatusInternalServerError), + }}, + }, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.handleBillingConfigWrite, + Summary: "Configure the billing data retention period.", + Responses: map[int][]framework.Response{ + http.StatusOK: {{ + Description: http.StatusText(http.StatusOK), + }}, + http.StatusNoContent: {{ + Description: http.StatusText(http.StatusNoContent), + }}, + http.StatusBadRequest: {{ + Description: http.StatusText(http.StatusBadRequest), + }}, + http.StatusInternalServerError: {{ + Description: http.StatusText(http.StatusInternalServerError), + }}, + }, + }, + }, + } +} + +func (b *SystemBackend) handleBillingConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + retentionMonths, err := b.Core.GetBillingRetentionMonths(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get billing retention configuration: %w", err) + } + + return &logical.Response{ + Data: map[string]interface{}{ + "retention_months": retentionMonths, + }, + }, nil +} + +func (b *SystemBackend) handleBillingConfigWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + retentionMonths := data.Get("retention_months").(int) + if retentionMonths < billing.MinBillingRetentionMonths || retentionMonths > billing.MaxBillingRetentionMonths { + return logical.ErrorResponse(fmt.Sprintf("retention_months must be between %d and %d months", billing.MinBillingRetentionMonths, billing.MaxBillingRetentionMonths)), logical.ErrInvalidRequest + } + + // Get current retention to check if it's being increased + currentRetention, err := b.Core.GetBillingRetentionMonths(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current billing retention configuration: %w", err) + } + + // Store the configuration + if err := b.Core.UpdateBillingRetentionMonths(ctx, retentionMonths); err != nil { + return nil, fmt.Errorf("failed to set billing retention configuration: %w", err) + } + + resp := &logical.Response{} + + // Add warning if retention period is being increased + if retentionMonths > currentRetention { + resp.Warnings = append(resp.Warnings, fmt.Sprintf( + "Retention period increased from %d to %d months. Historical data will only be available for months within the previous retention period. Older months outside the previous retention range will not have data.", + currentRetention, retentionMonths)) + } + + return resp, nil +} + +func (b *SystemBackend) handleBillingOverview(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { refreshData := data.Get("refresh_data").(bool) currentMonth := time.Now().UTC() @@ -91,7 +194,13 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic refreshData = false } - startMonth, endMonth, isOutOfRetention, err := parseStartEndMonths(data, currentMonth) + // Get the configured retention period + retentionMonths, err := b.Core.GetBillingRetentionMonths(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get billing retention configuration: %w", err) + } + + startMonth, endMonth, isOutOfRetention, err := parseStartEndMonths(data, currentMonth, retentionMonths) if err != nil { return nil, err } @@ -132,10 +241,10 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic } // parseStartEndMonths parses the start and end month parameters from the request and validates if they are valid. -// If they are outside of the BillingRetentionMonths range, it returns a warning. If no parameter is specified, -// the start and end defaults to the start of the BillingRetentionMonths range and the current month, respectively. -func parseStartEndMonths(data *framework.FieldData, currentMonth time.Time) (time.Time, time.Time, bool, error) { - defaultStartMonth := timeutil.StartOfMonth(currentMonth).AddDate(0, -billing.BillingRetentionMonths+1, 0) +// If they are outside of the retention range, it returns a warning. If no parameter is specified, +// the start and end defaults to the start of the retention range and the current month, respectively. +func parseStartEndMonths(data *framework.FieldData, currentMonth time.Time, retentionMonths int) (time.Time, time.Time, bool, error) { + defaultStartMonth := timeutil.StartOfMonth(currentMonth).AddDate(0, -retentionMonths+1, 0) defaultEndMonth := timeutil.StartOfMonth(currentMonth) parseMonth := func(key string, defaultMonth time.Time) (time.Time, error) { diff --git a/vault/logical_system_use_case_billing_test.go b/vault/logical_system_use_case_billing_test.go index f8604a8aa7..93a99a4c0e 100644 --- a/vault/logical_system_use_case_billing_test.go +++ b/vault/logical_system_use_case_billing_test.go @@ -39,13 +39,13 @@ func TestSystemBackend_BillingOverviewMonthFormat(t *testing.T) { // Verify the response structure months, ok := resp.Data["months"].([]interface{}) require.True(t, ok, "months should be a slice") - require.Len(t, months, billing.BillingRetentionMonths, "should have billing.BillingRetentionMonths months") + require.Len(t, months, billing.DefaultBillingRetentionMonths, "should have billing.DefaultBillingRetentionMonths months") now := time.Now() currentMonthStart := timeutil.StartOfMonth(now) // Loop through all months and verify format - for i := 0; i < billing.BillingRetentionMonths; i++ { + for i := 0; i < billing.DefaultBillingRetentionMonths; i++ { monthData, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) @@ -86,9 +86,9 @@ func TestSystemBackend_BillingOverview_StartEndMonthParams(t *testing.T) { previousMonth := timeutil.StartOfPreviousMonth(now).Format("2006-01") nextMonth := timeutil.StartOfNextMonth(now).Format("2006-01") twoMonthsAfterCurrent := timeutil.StartOfMonth(now).AddDate(0, 2, 0).Format("2006-01") - retentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths+1, 0).Format("2006-01") - beforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths, 0).Format("2006-01") - twoMonthsBeforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths-1, 0).Format("2006-01") + retentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.DefaultBillingRetentionMonths+1, 0).Format("2006-01") + beforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.DefaultBillingRetentionMonths, 0).Format("2006-01") + twoMonthsBeforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.DefaultBillingRetentionMonths-1, 0).Format("2006-01") testCases := []struct { name string @@ -107,20 +107,20 @@ func TestSystemBackend_BillingOverview_StartEndMonthParams(t *testing.T) { { name: "start before retention period, default end", startMonth: beforeRetentionStart, - expectedMonths: billing.BillingRetentionMonths + 1, + expectedMonths: billing.DefaultBillingRetentionMonths + 1, expectedWarning: WarningStartEndMonthOutOfRetentionRange, }, { name: "end after retention period, default start", endMonth: nextMonth, - expectedMonths: billing.BillingRetentionMonths + 1, + expectedMonths: billing.DefaultBillingRetentionMonths + 1, expectedWarning: WarningStartEndMonthOutOfRetentionRange, }, { name: "start is exactly start of retention period", startMonth: retentionStart, endMonth: previousMonth, - expectedMonths: billing.BillingRetentionMonths - 1, + expectedMonths: billing.DefaultBillingRetentionMonths - 1, }, { name: "start and end after retention period", @@ -143,7 +143,7 @@ func TestSystemBackend_BillingOverview_StartEndMonthParams(t *testing.T) { }, { name: "no parameters, default start and end", - expectedMonths: billing.BillingRetentionMonths, + expectedMonths: billing.DefaultBillingRetentionMonths, }, { name: "start after end", @@ -240,6 +240,102 @@ func TestSystemBackend_BillingOverview_StartEndMonthParams(t *testing.T) { } } +// TestSystemBackend_BillingOverview_StartEndMonthParams_CustomRetention tests that the +// billing overview endpoint correctly validates start_month and end_month parameters +// against a custom retention period. This verifies that the retention boundary logic +// uses the configured retention period rather than the default. +func TestSystemBackend_BillingOverview_StartEndMonthParams_CustomRetention(t *testing.T) { + c, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + // Set custom retention to 15 months + customRetention := 15 + err := c.UpdateBillingRetentionMonths(ctx, customRetention) + require.NoError(t, err) + + // Verify the custom retention was set + retentionMonths, err := c.GetBillingRetentionMonths(ctx) + require.NoError(t, err) + require.Equal(t, customRetention, retentionMonths) + + now := time.Now().UTC() + currentMonth := now.Format("2006-01") + previousMonth := timeutil.StartOfPreviousMonth(now).Format("2006-01") + + // With 15 months retention: + // - retentionStart is 14 months ago (oldest retained month) + // - beforeRetentionStart is 15 months ago (should trigger warning) + retentionStart := timeutil.StartOfMonth(now).AddDate(0, -(customRetention - 1), 0).Format("2006-01") + beforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -customRetention, 0).Format("2006-01") + twoMonthsBeforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -(customRetention + 1), 0).Format("2006-01") + + testCases := []struct { + name string + startMonth interface{} + endMonth interface{} + expectedMonths int + expectedWarning string + }{ + { + name: "start and end within custom retention period", + startMonth: previousMonth, + endMonth: currentMonth, + expectedMonths: 2, + }, + { + name: "start before custom retention period", + startMonth: beforeRetentionStart, + expectedMonths: customRetention + 1, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "start is exactly start of custom retention period", + startMonth: retentionStart, + endMonth: previousMonth, + expectedMonths: customRetention - 1, + }, + { + name: "start and end before custom retention period", + startMonth: twoMonthsBeforeRetentionStart, + endMonth: beforeRetentionStart, + expectedMonths: 2, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "no parameters with custom retention", + expectedMonths: customRetention, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") + if test.startMonth != nil { + req.Data["start_month"] = test.startMonth + } + if test.endMonth != nil { + req.Data["end_month"] = test.endMonth + } + + resp, err := b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Check for expected warning + if test.expectedWarning != "" { + require.NotNil(t, resp.Warnings) + require.Contains(t, resp.Warnings, test.expectedWarning) + } + + // Verify the number of months returned matches custom retention + months, ok := resp.Data["months"].([]interface{}) + require.True(t, ok) + require.Len(t, months, test.expectedMonths, + "should return %d months with custom retention of %d", test.expectedMonths, customRetention) + }) + } +} + // TestSystemBackend_BillingOverview_WithMetrics tests the billing overview endpoint // with actual KV secrets created to generate billing metrics. It verifies that KV v2 // secrets are properly counted in billing, the static_secrets metric appears in the @@ -280,7 +376,7 @@ func TestSystemBackend_BillingOverview_WithMetrics(t *testing.T) { // Verify the response contains metrics months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonthData, ok := months[0].(map[string]interface{}) require.True(t, ok) @@ -323,7 +419,7 @@ func TestSystemBackend_BillingOverview_WithMetrics(t *testing.T) { // Verify that all previous months (without data) have empty usage_metrics currentMonthStart := timeutil.StartOfMonth(currentMonth) - for i := 1; i < billing.BillingRetentionMonths; i++ { + for i := 1; i < billing.DefaultBillingRetentionMonths; i++ { monthData, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) @@ -503,7 +599,7 @@ func TestSystemBackend_BillingOverview_MetricTypeFormat(t *testing.T) { months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonthData, ok := months[0].(map[string]interface{}) require.True(t, ok) @@ -693,10 +789,10 @@ func TestSystemBackend_BillingOverview_HistoricalMonths(t *testing.T) { months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) // Loop through all months and verify timestamps - for i := 0; i < billing.BillingRetentionMonths; i++ { + for i := 0; i < billing.DefaultBillingRetentionMonths; i++ { monthData, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) @@ -745,7 +841,7 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) { // Verify the response structure exists months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) // Check current month has all metrics with zero values currentMonth, ok := months[0].(map[string]interface{}) @@ -903,7 +999,7 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) { } // Verify all previous months also have zero values - for i := 1; i < billing.BillingRetentionMonths; i++ { + for i := 1; i < billing.DefaultBillingRetentionMonths; i++ { monthData, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) require.Contains(t, monthData, "usage_metrics", "month %d should have usage_metrics", i) @@ -996,7 +1092,7 @@ func TestSystemBackend_BillingOverview_MultipleMetricTypes(t *testing.T) { months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonthData, ok := months[0].(map[string]interface{}) require.True(t, ok) @@ -1042,7 +1138,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonth, ok := months[0].(map[string]interface{}) require.True(t, ok) @@ -1059,7 +1155,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { require.Equal(t, firstTime, lastUpdate, "stored timestamp should match response timestamp") // Verify all previous months have zero timestamp (no data stored for them) - for i := 1; i < billing.BillingRetentionMonths; i++ { + for i := 1; i < billing.DefaultBillingRetentionMonths; i++ { prevMonth, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) @@ -1085,7 +1181,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { months, ok = resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonth, ok = months[0].(map[string]interface{}) require.True(t, ok) @@ -1105,7 +1201,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp(t *testing.T) { "updated_at without refresh should be identical to the stored timestamp") // Verify all previous months' timestamps remain the same (zero time) - for i := 1; i < billing.BillingRetentionMonths; i++ { + for i := 1; i < billing.DefaultBillingRetentionMonths; i++ { prevMonth, ok := months[i].(map[string]interface{}) require.True(t, ok, "month %d should be a map", i) @@ -1139,7 +1235,7 @@ func TestSystemBackend_BillingOverview_UpdatedAtTimestamp_NoStoredTimestamp(t *t months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) currentMonth, ok := months[0].(map[string]interface{}) require.True(t, ok) @@ -1184,7 +1280,7 @@ func TestSystemBackend_BillingOverview_AllMetricTypesPresent(t *testing.T) { // Verify the response structure exists months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) // Check current month has all metrics currentMonth, ok := months[0].(map[string]interface{}) @@ -1321,7 +1417,7 @@ func TestSystemBackend_BillingOverview_PreviousMonth_WithError(t *testing.T) { months, ok := resp.Data["months"].([]interface{}) require.True(t, ok) - require.Len(t, months, billing.BillingRetentionMonths) + require.Len(t, months, billing.DefaultBillingRetentionMonths) // Check previous month data previousMonthData, ok := months[1].(map[string]interface{}) @@ -1590,3 +1686,164 @@ func TestRoundUsageMetrics(t *testing.T) { }) } } + +// TestSystemBackend_BillingConfig_Read tests reading the billing retention configuration +func TestSystemBackend_BillingConfig_Read(t *testing.T) { + _, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + // Read config when not set - should return default + req := logical.TestRequest(t, logical.ReadOperation, "billing/config") + resp, err := b.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + require.Equal(t, billing.DefaultBillingRetentionMonths, resp.Data["retention_months"]) +} + +// TestSystemBackend_BillingConfig_Write tests writing the billing retention configuration +func TestSystemBackend_BillingConfig_Write(t *testing.T) { + _, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + testCases := []struct { + name string + retentionMonths int + expectError bool + errorContains string + expectWarning bool + warningContains string + }{ + { + name: "valid minimum value", + retentionMonths: billing.MinBillingRetentionMonths, + expectError: false, + expectWarning: false, // Less than default, no warning + }, + { + name: "valid maximum value", + retentionMonths: billing.MaxBillingRetentionMonths, + expectError: false, + expectWarning: true, // Greater than default (37), should warn + warningContains: "Retention period increased", + }, + { + name: "valid middle value below default", + retentionMonths: 24, + expectError: false, + expectWarning: false, // Less than default, no warning + }, + { + name: "valid middle value above default", + retentionMonths: 48, + expectError: false, + expectWarning: true, // Greater than default (37), should warn + warningContains: "Retention period increased", + }, + { + name: "below minimum", + retentionMonths: billing.MinBillingRetentionMonths - 1, + expectError: true, + errorContains: "must be between", + }, + { + name: "above maximum", + retentionMonths: billing.MaxBillingRetentionMonths + 1, + expectError: true, + errorContains: "must be between", + }, + { + name: "zero value", + retentionMonths: 0, + expectError: true, + errorContains: "must be between", + }, + { + name: "negative value", + retentionMonths: -1, + expectError: true, + errorContains: "must be between", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := logical.TestRequest(t, logical.UpdateOperation, "billing/config") + req.Data = map[string]interface{}{ + "retention_months": tc.retentionMonths, + } + + resp, err := b.HandleRequest(ctx, req) + + if tc.expectError { + require.Equal(t, logical.ErrInvalidRequest, err) + require.NotNil(t, resp) + require.True(t, resp.IsError()) + require.Contains(t, resp.Error().Error(), tc.errorContains) + } else { + require.NoError(t, err) + require.NotNil(t, resp) + require.False(t, resp.IsError()) + + // Check for warning when increasing retention + if tc.expectWarning { + require.Len(t, resp.Warnings, 1, "should have warning when increasing retention above default") + require.Contains(t, resp.Warnings[0], tc.warningContains) + } else { + require.Empty(t, resp.Warnings, "should not have warnings when not increasing retention") + } + + // Verify the value was stored by reading it back + readReq := logical.TestRequest(t, logical.ReadOperation, "billing/config") + readResp, err := b.HandleRequest(ctx, readReq) + require.NoError(t, err) + require.NotNil(t, readResp) + require.Equal(t, tc.retentionMonths, readResp.Data["retention_months"]) + } + }) + } +} + +// TestSystemBackend_BillingConfig_AffectsOverview tests that config affects billing overview +func TestSystemBackend_BillingConfig_AffectsOverview(t *testing.T) { + _, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + // Set retention to minimum (13 months) + writeReq := logical.TestRequest(t, logical.UpdateOperation, "billing/config") + writeReq.Data = map[string]interface{}{ + "retention_months": billing.MinBillingRetentionMonths, + } + _, err := b.HandleRequest(ctx, writeReq) + require.NoError(t, err) + + // Request billing overview + overviewReq := logical.TestRequest(t, logical.ReadOperation, "billing/overview") + resp, err := b.HandleRequest(ctx, overviewReq) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify we get 13 months of data + months, ok := resp.Data["months"].([]interface{}) + require.True(t, ok) + require.Len(t, months, billing.MinBillingRetentionMonths) + + // Set retention to maximum (72 months) + writeReq = logical.TestRequest(t, logical.UpdateOperation, "billing/config") + writeReq.Data = map[string]interface{}{ + "retention_months": billing.MaxBillingRetentionMonths, + } + _, err = b.HandleRequest(ctx, writeReq) + require.NoError(t, err) + + // Request billing overview again + overviewReq = logical.TestRequest(t, logical.ReadOperation, "billing/overview") + resp, err = b.HandleRequest(ctx, overviewReq) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify we get 72 months of data + months, ok = resp.Data["months"].([]interface{}) + require.True(t, ok) + require.Len(t, months, billing.MaxBillingRetentionMonths) +}