diff --git a/builtin/logical/pki/path_acme_order.go b/builtin/logical/pki/path_acme_order.go index 089589c81f..a46151aba5 100644 --- a/builtin/logical/pki/path_acme_order.go +++ b/builtin/logical/pki/path_acme_order.go @@ -297,7 +297,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, signedCertBundle.Certificate) } hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber) diff --git a/builtin/logical/pki/path_issue_sign.go b/builtin/logical/pki/path_issue_sign.go index 7d1e61e494..1c6e05f481 100644 --- a/builtin/logical/pki/path_issue_sign.go +++ b/builtin/logical/pki/path_issue_sign.go @@ -484,7 +484,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d } } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore, parsedBundle.Certificate) if useCSR { if role.UseCSRCommonName && data.Get("common_name").(string) != "" { diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 2b4940374e..62c051be05 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -308,7 +308,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate) // Build a fresh CRL warnings, err = b.CrlBuilder().Rebuild(sc, true) @@ -462,7 +462,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R if err != nil { return nil, err } - b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate) if warnAboutTruncate && signingBundle.Certificate.NotAfter.Equal(parsedBundle.Certificate.NotAfter) { diff --git a/builtin/logical/pki/test_helpers.go b/builtin/logical/pki/test_helpers.go index 781f51a048..9da615ff1d 100644 --- a/builtin/logical/pki/test_helpers.go +++ b/builtin/logical/pki/test_helpers.go @@ -525,11 +525,21 @@ func (c *testingPkiCertificateCounter) Increment() logical.CertCountIncrementer return logical.NewCertCountIncrementer(c) } -func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64) { +func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64, pkiDurationAdjustedCerts float64) { require.Equal(t, issuedCerts, c.count.IssuedCerts, "issued certificates count mismatch %s") require.Equal(t, storedCerts, c.count.StoredCerts, "stored certificates count mismatch %s") + require.Equal(t, pkiDurationAdjustedCerts, c.count.PkiDurationAdjustedCerts, "pki duration adjusted certificates count mismatch %s") } func (c *testingPkiCertificateCounter) RequireZero(t require.TestingT) { - c.RequireCount(t, 0, 0) + c.RequireCount(t, 0, 0, 0) +} + +// See logical.durationAdjustedCertificateCount for the billing specification and implementation details. +func adjustedCertificateCountFromDuration(validity time.Duration) float64 { + const standardDuration = 730.0 + validityHours := validity.Hours() + units := validityHours / standardDuration + // Round to 4 decimal places + return math.Round(units*10000) / 10000 } diff --git a/sdk/logical/certificate_counter.go b/sdk/logical/certificate_counter.go index 4633a40ce6..0504c06d4d 100644 --- a/sdk/logical/certificate_counter.go +++ b/sdk/logical/certificate_counter.go @@ -3,6 +3,11 @@ package logical +import ( + "crypto/x509" + "math" +) + // CertificateCounter is an interface for incrementing the count of issued and stored // certificates. type CertificateCounter interface { @@ -16,21 +21,43 @@ type CertificateCounter interface { // CertCount represents the parameters for incrementing certificate counts. type CertCount struct { - IssuedCerts uint64 + IssuedCerts uint64 // TODO(victorr): Rename to PkiIssuedCerts StoredCerts uint64 + + // PkiDurationAdjustedCerts tracks the normalized certificate duration units for billing + // purposes. Each certificate's billable units = (Validity Hours ÷ 730), rounded to 4 decimal + // places. + PkiDurationAdjustedCerts float64 } func (i *CertCount) Add(other CertCount) { i.IssuedCerts += other.IssuedCerts i.StoredCerts += other.StoredCerts + i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts } func (i *CertCount) IsZero() bool { - return i.IssuedCerts == 0 && i.StoredCerts == 0 + return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 +} + +// durationAdjustedCertificateCount calculates the billable units for a certificate based on its +// validity duration. WARNING: Beware the maximum value for time.Duration (approximately 290 years). +// +// The calculation follows the billing specification: +// - Standard duration is 730 hours (1 month) +// - Units = (Validity Hours ÷ 730), rounded to 4 decimal places +// - Example: 1-year cert (8760 hours) = 12.0000 units +// - Example: 1-day cert (24 hours) = 0.0329 units +func durationAdjustedCertificateCount(validitySeconds int64) float64 { + const standardDuration = 730.0 + validityHours := float64(validitySeconds) / 3600.0 + units := validityHours / standardDuration + // Round to 4 decimal places + return math.Round(units*10000) / 10000 } type CertCountIncrementer interface { - AddIssuedCertificate(stored bool) CertCountIncrementer + AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer } type certCountIncrementer struct { @@ -44,10 +71,16 @@ func NewCertCountIncrementer(counter CertificateCounter) CertCountIncrementer { return &certCountIncrementer{counter: counter} } -// AddIssuedCertificate increments the issued certificate count by 1, and also the -// stored certificate count if stored is true. -func (c *certCountIncrementer) AddIssuedCertificate(stored bool) CertCountIncrementer { - count := CertCount{IssuedCerts: 1} +// AddIssuedCertificate increments the issued certificate count by 1, the stored certificate +// count if stored is true, and adds the calculated billable units based on the certificate's +// validity duration. +// cert: The X.509 certificate to extract validity duration from. +func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer { + validity := int64(cert.NotAfter.Unix() - cert.NotBefore.Unix()) + count := CertCount{ + IssuedCerts: 1, + PkiDurationAdjustedCerts: durationAdjustedCertificateCount(validity), + } if stored { count.StoredCerts = 1 } diff --git a/sdk/logical/certificate_counter_test.go b/sdk/logical/certificate_counter_test.go new file mode 100644 index 0000000000..4d449878bf --- /dev/null +++ b/sdk/logical/certificate_counter_test.go @@ -0,0 +1,222 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package logical + +import ( + "math" + "testing" +) + +func Test_durationAdjustedCertificateCount(t *testing.T) { + tests := []struct { + name string + validitySeconds int64 + want float64 + }{ + { + name: "zero duration", + validitySeconds: 0, + want: 0.0, + }, + { + name: "1 hour", + validitySeconds: 3600, + want: 0.0014, // 1/730 rounded to 4 decimals + }, + { + name: "24 hours (1 day)", + validitySeconds: 86400, // 24 * 3600 + want: 0.0329, // 24/730 = 0.032876... rounded to 4 decimals + }, + { + name: "730 hours (standard duration - 1 month)", + validitySeconds: 2628000, // 730 * 3600 + want: 1.0, + }, + { + name: "8760 hours (1 year)", + validitySeconds: 31536000, // 365 * 24 * 3600 + want: 12.0, // 8760/730 = 12.0 + }, + { + name: "17520 hours (2 years)", + validitySeconds: 63072000, // 730 * 24 * 3600 + want: 24.0, // 17520/730 = 24.0 + }, + { + name: "87600 hours (10 years)", + validitySeconds: 315360000, // 3650 * 24 * 3600 + want: 120.0, // 87600/730 = 120.0 + }, + { + name: "90 days", + validitySeconds: 7776000, // 90 * 24 * 3600 + want: 2.9589, // 2160/730 = 2.958904... rounded to 4 decimals + }, + { + name: "365 days (1 year)", + validitySeconds: 31536000, // 365 * 24 * 3600 + want: 12.0, // 8760/730 = 12.0 + }, + { + name: "fractional result - 100 hours", + validitySeconds: 360000, // 100 * 3600 + want: 0.137, // 100/730 = 0.136986... rounded to 4 decimals + }, + { + name: "fractional result - 500 hours", + validitySeconds: 1800000, // 500 * 3600 + want: 0.6849, // 500/730 = 0.684931... rounded to 4 decimals + }, + { + name: "very small duration - 1 second", + validitySeconds: 1, + want: 0.0, // 1/3600/730 = 0.00000038... rounds to 0.0 + }, + { + name: "very small duration - 60 seconds", + validitySeconds: 60, + want: 0.0, // 60/3600/730 = 0.000023... rounds to 0.0 + }, + { + name: "very small duration - 600 seconds", + validitySeconds: 600, + want: 0.0002, // 600/3600/730 = 0.000228... rounds to 0.0002 + }, + { + name: "edge case - exactly rounds up", + validitySeconds: 13149, // Should result in value that rounds up + want: 0.0050, // 3.6525/730 = 0.005003... rounds to 0.0050 + }, + { + name: "edge case - exactly rounds down", + validitySeconds: 13140, // Should result in value that rounds down + want: 0.0050, // 3.65/730 = 0.005000 + }, + { + name: "large duration - 100 years", + validitySeconds: 3153600000, // 100 * 365 * 24 * 3600 + want: 1200.0, // 876000/730 = 1200.0 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := durationAdjustedCertificateCount(tt.validitySeconds) + if got != tt.want { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", tt.validitySeconds, got, tt.want) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_Precision(t *testing.T) { + // Test that the function properly rounds to 4 decimal places + tests := []struct { + name string + validitySeconds int64 + wantPrecision int // number of decimal places + }{ + { + name: "result has max 4 decimal places - case 1", + validitySeconds: 12345, + wantPrecision: 4, + }, + { + name: "result has max 4 decimal places - case 2", + validitySeconds: 98765, + wantPrecision: 4, + }, + { + name: "result has max 4 decimal places - case 3", + validitySeconds: 555555, + wantPrecision: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := durationAdjustedCertificateCount(tt.validitySeconds) + // Check that the result has at most 4 decimal places + // by multiplying by 10000 and checking if it's an integer + scaled := got * 10000 + if scaled != math.Floor(scaled) { + t.Errorf("durationAdjustedCertificateCount(%d) = %v has more than 4 decimal places", tt.validitySeconds, got) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_Consistency(t *testing.T) { + // Test that the function is consistent with the public wrapper + tests := []struct { + name string + validitySeconds int64 + }{ + {"1 hour", 3600}, + {"1 day", 86400}, + {"1 month", 2628000}, + {"1 year", 31536000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the internal function + internal := durationAdjustedCertificateCount(tt.validitySeconds) + + // The internal function should produce the same result as calculating manually + validityHours := float64(tt.validitySeconds) / 3600.0 + units := validityHours / 730.0 + expected := math.Round(units*10000) / 10000 + + if internal != expected { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, manual calculation = %v", tt.validitySeconds, internal, expected) + } + }) + } +} + +func Test_durationAdjustedCertificateCount_NegativeInput(t *testing.T) { + // Test behavior with negative input (edge case) + // Note: In practice, this shouldn't happen with valid certificates, + // but we should verify the function's behavior + validitySeconds := int64(-3600) + got := durationAdjustedCertificateCount(validitySeconds) + + // The function should handle negative values mathematically + // -1 hour / 730 hours = -0.0014 (rounded to 4 decimals) + want := -0.0014 + + if got != want { + t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", validitySeconds, got, want) + } +} + +func Benchmark_durationAdjustedCertificateCount(b *testing.B) { + validitySeconds := int64(31536000) // 1 year + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = durationAdjustedCertificateCount(validitySeconds) + } +} + +func Benchmark_durationAdjustedCertificateCount_Various(b *testing.B) { + testCases := []int64{ + 3600, // 1 hour + 86400, // 1 day + 2628000, // 1 month + 31536000, // 1 year + 315360000, // 10 years + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, validitySeconds := range testCases { + _ = durationAdjustedCertificateCount(validitySeconds) + } + } +} + +// Made with Bob diff --git a/vault/pki_cert_count/pki_cert_count_manager.go b/vault/pki_cert_count/pki_cert_count_manager.go index 503ca60fec..8926d7c64d 100644 --- a/vault/pki_cert_count/pki_cert_count_manager.go +++ b/vault/pki_cert_count/pki_cert_count_manager.go @@ -4,6 +4,7 @@ package pki_cert_count import ( + "crypto/x509" "os" "sync" "time" @@ -36,7 +37,7 @@ type PkiCertificateCountManager interface { // GetCounts returns the current counts of issued and stored certificates, without // consuming them. Meant to ease unit testing. - GetCounts() (issuedCount, storedCount uint64) + GetCounts() logical.CertCount } // certCountManager is an implementation of PkiCertificateCountManager. @@ -139,10 +140,12 @@ func (m *certCountManager) Increment() logical.CertCountIncrementer { return logical.NewCertCountIncrementer(m) } -func (m *certCountManager) GetCounts() (issuedCount, storedCount uint64) { +func (m *certCountManager) GetCounts() (issuedCount logical.CertCount) { m.countLock.RLock() defer m.countLock.RUnlock() - return m.count.IssuedCerts, m.count.StoredCerts + ret := logical.CertCount{} + ret.Add(m.count) + return ret } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -164,7 +167,7 @@ func (n *nullPkiCertificateCountManager) Increment() logical.CertCountIncremente return logical.NewCertCountIncrementer(n) } -func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool) { +func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool, _ *x509.Certificate) { // nothing to do } @@ -176,6 +179,6 @@ func (n *nullPkiCertificateCountManager) StopConsumerJob() { // nothing to do } -func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount, storedCount uint64) { - return 0, 0 +func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount logical.CertCount) { + return logical.CertCount{} } diff --git a/vault/pki_cert_count/pki_cert_count_manager_test.go b/vault/pki_cert_count/pki_cert_count_manager_test.go index 62a9b2a889..6aee5f74d8 100644 --- a/vault/pki_cert_count/pki_cert_count_manager_test.go +++ b/vault/pki_cert_count/pki_cert_count_manager_test.go @@ -4,6 +4,11 @@ package pki_cert_count import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" "sync" "sync/atomic" "testing" @@ -14,6 +19,34 @@ import ( "github.com/stretchr/testify/require" ) +// createTestCertificate creates a test certificate with the specified validity duration +func createTestCertificate(t *testing.T, validity time.Duration) *x509.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + notBefore := time.Now() + notAfter := notBefore.Add(validity) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "test-cert", + }, + NotBefore: notBefore, + NotAfter: notAfter, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + return cert +} + // TestPkiCertificateCountManager_IncrementAndConsume tests the behaviour of // PkiCertificateCountManager. func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { @@ -41,8 +74,15 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { manager.AddCount(logical.CertCount{IssuedCerts: 3, StoredCerts: 0}) manager.AddCount(logical.CertCount{IssuedCerts: 0, StoredCerts: 5}) - manager.Increment().AddIssuedCertificate(true) - manager.Increment().AddIssuedCertificate(false) + + // Create test certificates with different validity periods + // 730 hours = 1 month = 1.0 billable unit + cert1Month := createTestCertificate(t, 730*time.Hour) + // 8760 hours = 1 year = 12.0 billable units + cert1Year := createTestCertificate(t, 8760*time.Hour) + + manager.Increment().AddIssuedCertificate(true, cert1Month) + manager.Increment().AddIssuedCertificate(false, cert1Year) time.Sleep(100 * time.Millisecond) @@ -54,5 +94,7 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { require.Equal(t, uint64(5), jobCount.IssuedCerts, "issued count mismatch") require.Equal(t, uint64(6), jobCount.StoredCerts, "stored count mismatch") + // cert1Month: 730/730 = 1.0, cert1Year: 8760/730 = 12.0, total = 13.0 + require.InDelta(t, 13.0, jobCount.PkiDurationAdjustedCerts, 0.0001, "billable units mismatch") require.Zero(t, firstConsumerTotalCount.Load(), "first consumer should not have been called") } diff --git a/vault/testing.go b/vault/testing.go index 0fc4615fe7..8d7d1fb9d2 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1073,7 +1073,7 @@ type TestClusterCore struct { UnderlyingHAStorage physical.HABackend Barrier SecurityBarrier NodeID string - pkiCertificateCountData struct{ ignoredIssuedCount, ignoredStoredCount uint64 } + pkiCertificateCountData logical.CertCount } type PhysicalBackendBundle struct { diff --git a/vault/testing_util.go b/vault/testing_util.go index 0487759737..1a1352eb9f 100644 --- a/vault/testing_util.go +++ b/vault/testing_util.go @@ -29,20 +29,21 @@ func (c *TestClusterCore) StopPkiCertificateCountConsumerJob() { func (c *TestClusterCore) ResetPkiCertificateCounts() { mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager) - c.pkiCertificateCountData.ignoredIssuedCount, c.pkiCertificateCountData.ignoredStoredCount = mgr.GetCounts() + + c.pkiCertificateCountData = mgr.GetCounts() } func (c *TestClusterCore) RequirePkiCertificateCounts(t testing.TB, expectedIssuedCount, expectedStoredCount int) { t.Helper() mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager) - actualIssuedCount, actualStoredCount := mgr.GetCounts() + actualCount := mgr.GetCounts() - actualIssuedCount -= c.pkiCertificateCountData.ignoredIssuedCount - actualStoredCount -= c.pkiCertificateCountData.ignoredStoredCount + actualCount.IssuedCerts -= c.pkiCertificateCountData.IssuedCerts + actualCount.StoredCerts -= c.pkiCertificateCountData.StoredCerts - c.pkiCertificateCountData.ignoredIssuedCount += uint64(expectedIssuedCount) - c.pkiCertificateCountData.ignoredStoredCount += uint64(expectedStoredCount) + c.pkiCertificateCountData.IssuedCerts += uint64(expectedIssuedCount) + c.pkiCertificateCountData.StoredCerts += uint64(expectedStoredCount) - require.Equal(t, expectedIssuedCount, int(actualIssuedCount), "PKI certificate issued count mismatch") - require.Equal(t, expectedStoredCount, int(actualStoredCount), "PKI certificate stored count mismatch") + require.Equal(t, expectedIssuedCount, int(actualCount.IssuedCerts), "PKI certificate issued count mismatch") + require.Equal(t, expectedStoredCount, int(actualCount.StoredCerts), "PKI certificate stored count mismatch") }