From 9cfcfec78a5b60d4b0ebf51e7aad733c3983061e Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Fri, 13 Feb 2026 09:41:37 -0500 Subject: [PATCH] Count duration adjusted certificate counts for billing (#12286) (#12310) * Change PkiCertificateCountManager.GetCounts() to return a CertCount. * Add PkiDurationAdjustedCerts field to CertCount. Add a new field to CertCount to keep track of "duration adjusted" issued certificates. Add an x509.Certificate argument to CertCountIncrementer.AddIssuedCertificate. In the implementation, use the certificate's NotBefore and NotAfter fields to calculate the validity duration for the certificate, and use that to compute the duration adjusted units. * Add the issued certificate to calls to AddIssuedCertificate. * Add PkiDurationAdjustedCerts when forwarding counts. Add pki_duration_adjusted_certificate_count to IncrementPkiCount proto. Update replicationServiceHandler.IncrementPkiCertCountRequest to take into account the new field. * Run make proto. * Update testingPkiCertificateCounter to make assertions on time adjusted counts. * PR review: Don't use NotAfter.Sub(NotBefore), since time.Duration is max 290 years. * PR review: Move DurationAdjustedCertificateCount to logical.pki/test_helpers. Add Bob generated unit tests for logical.durationAdjustedCertificateCount. * Run make fmt. Co-authored-by: Victor Rodriguez Rizo --- builtin/logical/pki/path_acme_order.go | 2 +- builtin/logical/pki/path_issue_sign.go | 2 +- builtin/logical/pki/path_root.go | 4 +- builtin/logical/pki/test_helpers.go | 14 +- sdk/logical/certificate_counter.go | 47 +++- sdk/logical/certificate_counter_test.go | 222 ++++++++++++++++++ .../pki_cert_count/pki_cert_count_manager.go | 15 +- .../pki_cert_count_manager_test.go | 46 +++- vault/testing.go | 2 +- vault/testing_util.go | 17 +- 10 files changed, 341 insertions(+), 30 deletions(-) create mode 100644 sdk/logical/certificate_counter_test.go diff --git a/builtin/logical/pki/path_acme_order.go b/builtin/logical/pki/path_acme_order.go index 089589c81f..a46151aba5 100644 --- a/builtin/logical/pki/path_acme_order.go +++ b/builtin/logical/pki/path_acme_order.go @@ -297,7 +297,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, signedCertBundle.Certificate) } hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber) diff --git a/builtin/logical/pki/path_issue_sign.go b/builtin/logical/pki/path_issue_sign.go index 7d1e61e494..1c6e05f481 100644 --- a/builtin/logical/pki/path_issue_sign.go +++ b/builtin/logical/pki/path_issue_sign.go @@ -484,7 +484,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d } } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore, parsedBundle.Certificate) if useCSR { if role.UseCSRCommonName && data.Get("common_name").(string) != "" { diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 2b4940374e..62c051be05 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -308,7 +308,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate) // Build a fresh CRL warnings, err = b.CrlBuilder().Rebuild(sc, true) @@ -462,7 +462,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate) if warnAboutTruncate && signingBundle.Certificate.NotAfter.Equal(parsedBundle.Certificate.NotAfter) { diff --git a/builtin/logical/pki/test_helpers.go b/builtin/logical/pki/test_helpers.go index 781f51a048..9da615ff1d 100644 --- a/builtin/logical/pki/test_helpers.go +++ b/builtin/logical/pki/test_helpers.go @@ -525,11 +525,21 @@ func (c *testingPkiCertificateCounter) Increment() logical.CertCountIncrementer return logical.NewCertCountIncrementer(c) } -func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64) { +func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64, pkiDurationAdjustedCerts float64) { require.Equal(t, issuedCerts, c.count.IssuedCerts, "issued certificates count mismatch %s") require.Equal(t, storedCerts, c.count.StoredCerts, "stored certificates count mismatch %s") + require.Equal(t, pkiDurationAdjustedCerts, c.count.PkiDurationAdjustedCerts, "pki duration adjusted certificates count mismatch %s") } func (c *testingPkiCertificateCounter) RequireZero(t require.TestingT) { - c.RequireCount(t, 0, 0) + c.RequireCount(t, 0, 0, 0) +} + +// See logical.durationAdjustedCertificateCount for the billing specification and implementation details. +func adjustedCertificateCountFromDuration(validity time.Duration) float64 { + const standardDuration = 730.0 + validityHours := validity.Hours() + units := validityHours / standardDuration + // Round to 4 decimal places + return math.Round(units*10000) / 10000 } diff --git a/sdk/logical/certificate_counter.go b/sdk/logical/certificate_counter.go index 4633a40ce6..0504c06d4d 100644 --- a/sdk/logical/certificate_counter.go +++ b/sdk/logical/certificate_counter.go @@ -3,6 +3,11 @@ package logical +import ( + "crypto/x509" + "math" +) + // CertificateCounter is an interface for incrementing the count of issued and stored // certificates. type CertificateCounter interface { @@ -16,21 +21,43 @@ type CertificateCounter interface { // CertCount represents the parameters for incrementing certificate counts. type CertCount struct { - IssuedCerts uint64 + IssuedCerts uint64 // TODO(victorr): Rename to PkiIssuedCerts StoredCerts uint64 + + // PkiDurationAdjustedCerts tracks the normalized certificate duration units for billing + // purposes. Each certificate's billable units = (Validity Hours ÷ 730), rounded to 4 decimal + // places. + PkiDurationAdjustedCerts float64 } func (i *CertCount) Add(other CertCount) { i.IssuedCerts += other.IssuedCerts i.StoredCerts += other.StoredCerts + i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts } func (i *CertCount) IsZero() bool { - return i.IssuedCerts == 0 && i.StoredCerts == 0 + return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 +} + +// durationAdjustedCertificateCount calculates the billable units for a certificate based on its +// validity duration. WARNING: Beware the maximum value for time.Duration (approximately 290 years). +// +// The calculation follows the billing specification: +// - Standard duration is 730 hours (1 month) +// - Units = (Validity Hours ÷ 730), rounded to 4 decimal places +// - Example: 1-year cert (8760 hours) = 12.0000 units +// - Example: 1-day cert (24 hours) = 0.0329 units +func durationAdjustedCertificateCount(validitySeconds int64) float64 { + const standardDuration = 730.0 + validityHours := float64(validitySeconds) / 3600.0 + units := validityHours / standardDuration + // Round to 4 decimal places + return math.Round(units*10000) / 10000 } type CertCountIncrementer interface { - AddIssuedCertificate(stored bool) CertCountIncrementer + AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer } type certCountIncrementer struct { @@ -44,10 +71,16 @@ func NewCertCountIncrementer(counter CertificateCounter) CertCountIncrementer { return &certCountIncrementer{counter: counter} } -// AddIssuedCertificate increments the issued certificate count by 1, and also the -// stored certificate count if stored is true. -func (c *certCountIncrementer) AddIssuedCertificate(stored bool) CertCountIncrementer { - count := CertCount{IssuedCerts: 1} +// AddIssuedCertificate increments the issued certificate count by 1, the stored certificate +// count if stored is true, and adds the calculated billable units based on the certificate's +// validity duration. +// cert: The X.509 certificate to extract validity duration from. +func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer { + validity := int64(cert.NotAfter.Unix() - cert.NotBefore.Unix()) + count := CertCount{ + IssuedCerts: 1, + PkiDurationAdjustedCerts: durationAdjustedCertificateCount(validity), + } if stored { count.StoredCerts = 1 } diff --git a/sdk/logical/certificate_counter_test.go b/sdk/logical/certificate_counter_test.go new file mode 100644 index 0000000000..4d449878bf --- /dev/null +++ b/sdk/logical/certificate_counter_test.go @@ -0,0 +1,222 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package logical + +import ( + "math" + "testing" +) + +func Test_durationAdjustedCertificateCount(t *testing.T) { + tests := []struct { + name string + validitySeconds int64 + want float64 + }{ + { + name: "zero duration", + validitySeconds: 0, + want: 0.0, + }, + { + name: "1 hour", + validitySeconds: 3600, + want: 0.0014, // 1/730 rounded to 4 decimals + }, + { + name: "24 hours (1 day)", + validitySeconds: 86400, // 24 * 3600 + want: 0.0329, // 24/730 = 0.032876... rounded to 4 decimals + }, + { + name: "730 hours (standard duration - 1 month)", + validitySeconds: 2628000, // 730 * 3600 + want: 1.0, + }, + { + name: "8760 hours (1 year)", + validitySeconds: 31536000, // 365 * 24 * 3600 + want: 12.0, // 8760/730 = 12.0 + }, + { + name: "17520 hours (2 years)", + validitySeconds: 63072000, // 730 * 24 * 3600 + want: 24.0, // 17520/730 = 24.0 + }, + { + name: "87600 hours (10 years)", + validitySeconds: 315360000, // 3650 * 24 * 3600 + want: 120.0, // 87600/730 = 120.0 + }, + { + name: "90 days", + validitySeconds: 7776000, // 90 * 24 * 3600 + want: 2.9589, // 2160/730 = 2.958904... rounded to 4 decimals + }, + { + name: "365 days (1 year)", + validitySeconds: 31536000, // 365 * 24 * 3600 + want: 12.0, // 8760/730 = 12.0 + }, + { + name: "fractional result - 100 hours", + validitySeconds: 360000, // 100 * 3600 + want: 0.137, // 100/730 = 0.136986... rounded to 4 decimals + }, + { + name: "fractional result - 500 hours", + validitySeconds: 1800000, // 500 * 3600 + want: 0.6849, // 500/730 = 0.684931... rounded to 4 decimals + }, + { + name: "very small duration - 1 second", + validitySeconds: 1, + want: 0.0, // 1/3600/730 = 0.00000038... rounds to 0.0 + }, + { + name: "very small duration - 60 seconds", + validitySeconds: 60, + want: 0.0, // 60/3600/730 = 0.000023... rounds to 0.0 + }, + { + name: "very small duration - 600 seconds", + validitySeconds: 600, + want: 0.0002, // 600/3600/730 = 0.000228... rounds to 0.0002 + }, + { + name: "edge case - exactly rounds up", + validitySeconds: 13149, // Should result in value that rounds up + want: 0.0050, // 3.6525/730 = 0.005003... rounds to 0.0050 + }, + { + name: "edge case - exactly rounds down", + validitySeconds: 13140, // Should result in value that rounds down + want: 0.0050, // 3.65/730 = 0.005000 + }, + { + name: "large duration - 100 years", + validitySeconds: 3153600000, // 100 * 365 * 24 * 3600 + want: 1200.0, // 876000/730 = 1200.0 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := durationAdjustedCertificateCount(tt.validitySeconds) + if got != tt.want { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", tt.validitySeconds, got, tt.want) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_Precision(t *testing.T) { + // Test that the function properly rounds to 4 decimal places + tests := []struct { + name string + validitySeconds int64 + wantPrecision int // number of decimal places + }{ + { + name: "result has max 4 decimal places - case 1", + validitySeconds: 12345, + wantPrecision: 4, + }, + { + name: "result has max 4 decimal places - case 2", + validitySeconds: 98765, + wantPrecision: 4, + }, + { + name: "result has max 4 decimal places - case 3", + validitySeconds: 555555, + wantPrecision: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := durationAdjustedCertificateCount(tt.validitySeconds) + // Check that the result has at most 4 decimal places + // by multiplying by 10000 and checking if it's an integer + scaled := got * 10000 + if scaled != math.Floor(scaled) { + t.Errorf("durationAdjustedCertificateCount(%d) = %v has more than 4 decimal places", tt.validitySeconds, got) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_Consistency(t *testing.T) { + // Test that the function is consistent with the public wrapper + tests := []struct { + name string + validitySeconds int64 + }{ + {"1 hour", 3600}, + {"1 day", 86400}, + {"1 month", 2628000}, + {"1 year", 31536000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the internal function + internal := durationAdjustedCertificateCount(tt.validitySeconds) + + // The internal function should produce the same result as calculating manually + validityHours := float64(tt.validitySeconds) / 3600.0 + units := validityHours / 730.0 + expected := math.Round(units*10000) / 10000 + + if internal != expected { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, manual calculation = %v", tt.validitySeconds, internal, expected) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_NegativeInput(t *testing.T) { + // Test behavior with negative input (edge case) + // Note: In practice, this shouldn't happen with valid certificates, + // but we should verify the function's behavior + validitySeconds := int64(-3600) + got := durationAdjustedCertificateCount(validitySeconds) + + // The function should handle negative values mathematically + // -1 hour / 730 hours = -0.0014 (rounded to 4 decimals) + want := -0.0014 + + if got != want { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", validitySeconds, got, want) + } +} + +func Benchmark_durationAdjustedCertificateCount(b *testing.B) { + validitySeconds := int64(31536000) // 1 year + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = durationAdjustedCertificateCount(validitySeconds) + } +} + +func Benchmark_durationAdjustedCertificateCount_Various(b *testing.B) { + testCases := []int64{ + 3600, // 1 hour + 86400, // 1 day + 2628000, // 1 month + 31536000, // 1 year + 315360000, // 10 years + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, validitySeconds := range testCases { + _ = durationAdjustedCertificateCount(validitySeconds) + } + } +} + +// Made with Bob diff --git a/vault/pki_cert_count/pki_cert_count_manager.go b/vault/pki_cert_count/pki_cert_count_manager.go index 503ca60fec..8926d7c64d 100644 --- a/vault/pki_cert_count/pki_cert_count_manager.go +++ b/vault/pki_cert_count/pki_cert_count_manager.go @@ -4,6 +4,7 @@ package pki_cert_count import ( + "crypto/x509" "os" "sync" "time" @@ -36,7 +37,7 @@ type PkiCertificateCountManager interface { // GetCounts returns the current counts of issued and stored certificates, without // consuming them. Meant to ease unit testing. - GetCounts() (issuedCount, storedCount uint64) + GetCounts() logical.CertCount } // certCountManager is an implementation of PkiCertificateCountManager. @@ -139,10 +140,12 @@ func (m *certCountManager) Increment() logical.CertCountIncrementer { return logical.NewCertCountIncrementer(m) } -func (m *certCountManager) GetCounts() (issuedCount, storedCount uint64) { +func (m *certCountManager) GetCounts() (issuedCount logical.CertCount) { m.countLock.RLock() defer m.countLock.RUnlock() - return m.count.IssuedCerts, m.count.StoredCerts + ret := logical.CertCount{} + ret.Add(m.count) + return ret } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -164,7 +167,7 @@ func (n *nullPkiCertificateCountManager) Increment() logical.CertCountIncremente return logical.NewCertCountIncrementer(n) } -func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool) { +func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool, _ *x509.Certificate) { // nothing to do } @@ -176,6 +179,6 @@ func (n *nullPkiCertificateCountManager) StopConsumerJob() { // nothing to do } -func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount, storedCount uint64) { - return 0, 0 +func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount logical.CertCount) { + return logical.CertCount{} } diff --git a/vault/pki_cert_count/pki_cert_count_manager_test.go b/vault/pki_cert_count/pki_cert_count_manager_test.go index 62a9b2a889..6aee5f74d8 100644 --- a/vault/pki_cert_count/pki_cert_count_manager_test.go +++ b/vault/pki_cert_count/pki_cert_count_manager_test.go @@ -4,6 +4,11 @@ package pki_cert_count import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" "sync" "sync/atomic" "testing" @@ -14,6 +19,34 @@ import ( "github.com/stretchr/testify/require" ) +// createTestCertificate creates a test certificate with the specified validity duration +func createTestCertificate(t *testing.T, validity time.Duration) *x509.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + notBefore := time.Now() + notAfter := notBefore.Add(validity) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "test-cert", + }, + NotBefore: notBefore, + NotAfter: notAfter, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + return cert +} + // TestPkiCertificateCountManager_IncrementAndConsume tests the behaviour of // PkiCertificateCountManager. func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { @@ -41,8 +74,15 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { manager.AddCount(logical.CertCount{IssuedCerts: 3, StoredCerts: 0}) manager.AddCount(logical.CertCount{IssuedCerts: 0, StoredCerts: 5}) - manager.Increment().AddIssuedCertificate(true) - manager.Increment().AddIssuedCertificate(false) + + // Create test certificates with different validity periods + // 730 hours = 1 month = 1.0 billable unit + cert1Month := createTestCertificate(t, 730*time.Hour) + // 8760 hours = 1 year = 12.0 billable units + cert1Year := createTestCertificate(t, 8760*time.Hour) + + manager.Increment().AddIssuedCertificate(true, cert1Month) + manager.Increment().AddIssuedCertificate(false, cert1Year) time.Sleep(100 * time.Millisecond) @@ -54,5 +94,7 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { require.Equal(t, uint64(5), jobCount.IssuedCerts, "issued count mismatch") require.Equal(t, uint64(6), jobCount.StoredCerts, "stored count mismatch") + // cert1Month: 730/730 = 1.0, cert1Year: 8760/730 = 12.0, total = 13.0 + require.InDelta(t, 13.0, jobCount.PkiDurationAdjustedCerts, 0.0001, "billable units mismatch") require.Zero(t, firstConsumerTotalCount.Load(), "first consumer should not have been called") } diff --git a/vault/testing.go b/vault/testing.go index 0fc4615fe7..8d7d1fb9d2 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1073,7 +1073,7 @@ type TestClusterCore struct { UnderlyingHAStorage physical.HABackend Barrier SecurityBarrier NodeID string - pkiCertificateCountData struct{ ignoredIssuedCount, ignoredStoredCount uint64 } + pkiCertificateCountData logical.CertCount } type PhysicalBackendBundle struct { diff --git a/vault/testing_util.go b/vault/testing_util.go index 0487759737..1a1352eb9f 100644 --- a/vault/testing_util.go +++ b/vault/testing_util.go @@ -29,20 +29,21 @@ func (c *TestClusterCore) StopPkiCertificateCountConsumerJob() { func (c *TestClusterCore) ResetPkiCertificateCounts() { mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager) - c.pkiCertificateCountData.ignoredIssuedCount, c.pkiCertificateCountData.ignoredStoredCount = mgr.GetCounts() + + c.pkiCertificateCountData = mgr.GetCounts() } func (c *TestClusterCore) RequirePkiCertificateCounts(t testing.TB, expectedIssuedCount, expectedStoredCount int) { t.Helper() mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager) - actualIssuedCount, actualStoredCount := mgr.GetCounts() + actualCount := mgr.GetCounts() - actualIssuedCount -= c.pkiCertificateCountData.ignoredIssuedCount - actualStoredCount -= c.pkiCertificateCountData.ignoredStoredCount + actualCount.IssuedCerts -= c.pkiCertificateCountData.IssuedCerts + actualCount.StoredCerts -= c.pkiCertificateCountData.StoredCerts - c.pkiCertificateCountData.ignoredIssuedCount += uint64(expectedIssuedCount) - c.pkiCertificateCountData.ignoredStoredCount += uint64(expectedStoredCount) + c.pkiCertificateCountData.IssuedCerts += uint64(expectedIssuedCount) + c.pkiCertificateCountData.StoredCerts += uint64(expectedStoredCount) - require.Equal(t, expectedIssuedCount, int(actualIssuedCount), "PKI certificate issued count mismatch") - require.Equal(t, expectedStoredCount, int(actualStoredCount), "PKI certificate stored count mismatch") + require.Equal(t, expectedIssuedCount, int(actualCount.IssuedCerts), "PKI certificate issued count mismatch") + require.Equal(t, expectedStoredCount, int(actualCount.StoredCerts), "PKI certificate stored count mismatch") }