Count duration adjusted certificate counts for billing (#12286) (#12310)

* Change PkiCertificateCountManager.GetCounts() to return a CertCount.

* Add PkiDurationAdjustedCerts field to CertCount.

Add a new field to CertCount to keep track of "duration adjusted" issued
certificates.

Add an x509.Certificate argument to CertCountIncrementer.AddIssuedCertificate.
In the implementation, use the certificate's NotBefore and NotAfter fields to
calculate the validity duration for the certificate, and use that to compute the
duration adjusted units.

* Add the issued certificate to calls to AddIssuedCertificate.

* Add PkiDurationAdjustedCerts when forwarding counts.

Add pki_duration_adjusted_certificate_count to IncrementPkiCount proto.

Update replicationServiceHandler.IncrementPkiCertCountRequest to take into
account the new field.

* Run make proto.

* Update testingPkiCertificateCounter to make assertions on time adjusted counts.

* PR review: Don't use NotAfter.Sub(NotBefore), since time.Duration is max 290 years.

* PR review: Move DurationAdjustedCertificateCount to logical.pki/test_helpers.

Add Bob generated unit tests for logical.durationAdjustedCertificateCount.

* Run make fmt.

Co-authored-by: Victor Rodriguez Rizo <vrizo@hashicorp.com>
This commit is contained in:
Vault Automation 2026-02-13 09:41:37 -05:00 committed by GitHub
parent 22e5336265
commit 9cfcfec78a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 341 additions and 30 deletions

View file

@ -297,7 +297,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
if err != nil {
return nil, err
}
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, signedCertBundle.Certificate)
}
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)

View file

@ -484,7 +484,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
}
}
b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore, parsedBundle.Certificate)
if useCSR {
if role.UseCSRCommonName && data.Get("common_name").(string) != "" {

View file

@ -308,7 +308,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
if err != nil {
return nil, err
}
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate)
// Build a fresh CRL
warnings, err = b.CrlBuilder().Rebuild(sc, true)
@ -462,7 +462,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
if err != nil {
return nil, err
}
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true, parsedBundle.Certificate)
if warnAboutTruncate &&
signingBundle.Certificate.NotAfter.Equal(parsedBundle.Certificate.NotAfter) {

View file

@ -525,11 +525,21 @@ func (c *testingPkiCertificateCounter) Increment() logical.CertCountIncrementer
return logical.NewCertCountIncrementer(c)
}
func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64) {
func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64, pkiDurationAdjustedCerts float64) {
require.Equal(t, issuedCerts, c.count.IssuedCerts, "issued certificates count mismatch %s")
require.Equal(t, storedCerts, c.count.StoredCerts, "stored certificates count mismatch %s")
require.Equal(t, pkiDurationAdjustedCerts, c.count.PkiDurationAdjustedCerts, "pki duration adjusted certificates count mismatch %s")
}
func (c *testingPkiCertificateCounter) RequireZero(t require.TestingT) {
c.RequireCount(t, 0, 0)
c.RequireCount(t, 0, 0, 0)
}
// See logical.durationAdjustedCertificateCount for the billing specification and implementation details.
func adjustedCertificateCountFromDuration(validity time.Duration) float64 {
const standardDuration = 730.0
validityHours := validity.Hours()
units := validityHours / standardDuration
// Round to 4 decimal places
return math.Round(units*10000) / 10000
}

View file

@ -3,6 +3,11 @@
package logical
import (
"crypto/x509"
"math"
)
// CertificateCounter is an interface for incrementing the count of issued and stored
// certificates.
type CertificateCounter interface {
@ -16,21 +21,43 @@ type CertificateCounter interface {
// CertCount represents the parameters for incrementing certificate counts.
type CertCount struct {
IssuedCerts uint64
IssuedCerts uint64 // TODO(victorr): Rename to PkiIssuedCerts
StoredCerts uint64
// PkiDurationAdjustedCerts tracks the normalized certificate duration units for billing
// purposes. Each certificate's billable units = (Validity Hours ÷ 730), rounded to 4 decimal
// places.
PkiDurationAdjustedCerts float64
}
func (i *CertCount) Add(other CertCount) {
i.IssuedCerts += other.IssuedCerts
i.StoredCerts += other.StoredCerts
i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts
}
func (i *CertCount) IsZero() bool {
return i.IssuedCerts == 0 && i.StoredCerts == 0
return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0
}
// durationAdjustedCertificateCount calculates the billable units for a certificate based on its
// validity duration. WARNING: Beware the maximum value for time.Duration (approximately 290 years).
//
// The calculation follows the billing specification:
// - Standard duration is 730 hours (1 month)
// - Units = (Validity Hours ÷ 730), rounded to 4 decimal places
// - Example: 1-year cert (8760 hours) = 12.0000 units
// - Example: 1-day cert (24 hours) = 0.0329 units
func durationAdjustedCertificateCount(validitySeconds int64) float64 {
const standardDuration = 730.0
validityHours := float64(validitySeconds) / 3600.0
units := validityHours / standardDuration
// Round to 4 decimal places
return math.Round(units*10000) / 10000
}
type CertCountIncrementer interface {
AddIssuedCertificate(stored bool) CertCountIncrementer
AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer
}
type certCountIncrementer struct {
@ -44,10 +71,16 @@ func NewCertCountIncrementer(counter CertificateCounter) CertCountIncrementer {
return &certCountIncrementer{counter: counter}
}
// AddIssuedCertificate increments the issued certificate count by 1, and also the
// stored certificate count if stored is true.
func (c *certCountIncrementer) AddIssuedCertificate(stored bool) CertCountIncrementer {
count := CertCount{IssuedCerts: 1}
// AddIssuedCertificate increments the issued certificate count by 1, the stored certificate
// count if stored is true, and adds the calculated billable units based on the certificate's
// validity duration.
// cert: The X.509 certificate to extract validity duration from.
func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer {
validity := int64(cert.NotAfter.Unix() - cert.NotBefore.Unix())
count := CertCount{
IssuedCerts: 1,
PkiDurationAdjustedCerts: durationAdjustedCertificateCount(validity),
}
if stored {
count.StoredCerts = 1
}

View file

@ -0,0 +1,222 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: MPL-2.0
package logical
import (
"math"
"testing"
)
func Test_durationAdjustedCertificateCount(t *testing.T) {
tests := []struct {
name string
validitySeconds int64
want float64
}{
{
name: "zero duration",
validitySeconds: 0,
want: 0.0,
},
{
name: "1 hour",
validitySeconds: 3600,
want: 0.0014, // 1/730 rounded to 4 decimals
},
{
name: "24 hours (1 day)",
validitySeconds: 86400, // 24 * 3600
want: 0.0329, // 24/730 = 0.032876... rounded to 4 decimals
},
{
name: "730 hours (standard duration - 1 month)",
validitySeconds: 2628000, // 730 * 3600
want: 1.0,
},
{
name: "8760 hours (1 year)",
validitySeconds: 31536000, // 365 * 24 * 3600
want: 12.0, // 8760/730 = 12.0
},
{
name: "17520 hours (2 years)",
validitySeconds: 63072000, // 730 * 24 * 3600
want: 24.0, // 17520/730 = 24.0
},
{
name: "87600 hours (10 years)",
validitySeconds: 315360000, // 3650 * 24 * 3600
want: 120.0, // 87600/730 = 120.0
},
{
name: "90 days",
validitySeconds: 7776000, // 90 * 24 * 3600
want: 2.9589, // 2160/730 = 2.958904... rounded to 4 decimals
},
{
name: "365 days (1 year)",
validitySeconds: 31536000, // 365 * 24 * 3600
want: 12.0, // 8760/730 = 12.0
},
{
name: "fractional result - 100 hours",
validitySeconds: 360000, // 100 * 3600
want: 0.137, // 100/730 = 0.136986... rounded to 4 decimals
},
{
name: "fractional result - 500 hours",
validitySeconds: 1800000, // 500 * 3600
want: 0.6849, // 500/730 = 0.684931... rounded to 4 decimals
},
{
name: "very small duration - 1 second",
validitySeconds: 1,
want: 0.0, // 1/3600/730 = 0.00000038... rounds to 0.0
},
{
name: "very small duration - 60 seconds",
validitySeconds: 60,
want: 0.0, // 60/3600/730 = 0.000023... rounds to 0.0
},
{
name: "very small duration - 600 seconds",
validitySeconds: 600,
want: 0.0002, // 600/3600/730 = 0.000228... rounds to 0.0002
},
{
name: "edge case - exactly rounds up",
validitySeconds: 13149, // Should result in value that rounds up
want: 0.0050, // 3.6525/730 = 0.005003... rounds to 0.0050
},
{
name: "edge case - exactly rounds down",
validitySeconds: 13140, // Should result in value that rounds down
want: 0.0050, // 3.65/730 = 0.005000
},
{
name: "large duration - 100 years",
validitySeconds: 3153600000, // 100 * 365 * 24 * 3600
want: 1200.0, // 876000/730 = 1200.0
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := durationAdjustedCertificateCount(tt.validitySeconds)
if got != tt.want {
t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", tt.validitySeconds, got, tt.want)
}
})
}
}
func Test_durationAdjustedCertificateCount_Precision(t *testing.T) {
// Test that the function properly rounds to 4 decimal places
tests := []struct {
name string
validitySeconds int64
wantPrecision int // number of decimal places
}{
{
name: "result has max 4 decimal places - case 1",
validitySeconds: 12345,
wantPrecision: 4,
},
{
name: "result has max 4 decimal places - case 2",
validitySeconds: 98765,
wantPrecision: 4,
},
{
name: "result has max 4 decimal places - case 3",
validitySeconds: 555555,
wantPrecision: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := durationAdjustedCertificateCount(tt.validitySeconds)
// Check that the result has at most 4 decimal places
// by multiplying by 10000 and checking if it's an integer
scaled := got * 10000
if scaled != math.Floor(scaled) {
t.Errorf("durationAdjustedCertificateCount(%d) = %v has more than 4 decimal places", tt.validitySeconds, got)
}
})
}
}
func Test_durationAdjustedCertificateCount_Consistency(t *testing.T) {
// Test that the function is consistent with the public wrapper
tests := []struct {
name string
validitySeconds int64
}{
{"1 hour", 3600},
{"1 day", 86400},
{"1 month", 2628000},
{"1 year", 31536000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Call the internal function
internal := durationAdjustedCertificateCount(tt.validitySeconds)
// The internal function should produce the same result as calculating manually
validityHours := float64(tt.validitySeconds) / 3600.0
units := validityHours / 730.0
expected := math.Round(units*10000) / 10000
if internal != expected {
t.Errorf("durationAdjustedCertificateCount(%d) = %v, manual calculation = %v", tt.validitySeconds, internal, expected)
}
})
}
}
func Test_durationAdjustedCertificateCount_NegativeInput(t *testing.T) {
// Test behavior with negative input (edge case)
// Note: In practice, this shouldn't happen with valid certificates,
// but we should verify the function's behavior
validitySeconds := int64(-3600)
got := durationAdjustedCertificateCount(validitySeconds)
// The function should handle negative values mathematically
// -1 hour / 730 hours = -0.0014 (rounded to 4 decimals)
want := -0.0014
if got != want {
t.Errorf("durationAdjustedCertificateCount(%d) = %v, want %v", validitySeconds, got, want)
}
}
func Benchmark_durationAdjustedCertificateCount(b *testing.B) {
validitySeconds := int64(31536000) // 1 year
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = durationAdjustedCertificateCount(validitySeconds)
}
}
func Benchmark_durationAdjustedCertificateCount_Various(b *testing.B) {
testCases := []int64{
3600, // 1 hour
86400, // 1 day
2628000, // 1 month
31536000, // 1 year
315360000, // 10 years
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, validitySeconds := range testCases {
_ = durationAdjustedCertificateCount(validitySeconds)
}
}
}
// Made with Bob

View file

@ -4,6 +4,7 @@
package pki_cert_count
import (
"crypto/x509"
"os"
"sync"
"time"
@ -36,7 +37,7 @@ type PkiCertificateCountManager interface {
// GetCounts returns the current counts of issued and stored certificates, without
// consuming them. Meant to ease unit testing.
GetCounts() (issuedCount, storedCount uint64)
GetCounts() logical.CertCount
}
// certCountManager is an implementation of PkiCertificateCountManager.
@ -139,10 +140,12 @@ func (m *certCountManager) Increment() logical.CertCountIncrementer {
return logical.NewCertCountIncrementer(m)
}
func (m *certCountManager) GetCounts() (issuedCount, storedCount uint64) {
func (m *certCountManager) GetCounts() (issuedCount logical.CertCount) {
m.countLock.RLock()
defer m.countLock.RUnlock()
return m.count.IssuedCerts, m.count.StoredCerts
ret := logical.CertCount{}
ret.Add(m.count)
return ret
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -164,7 +167,7 @@ func (n *nullPkiCertificateCountManager) Increment() logical.CertCountIncremente
return logical.NewCertCountIncrementer(n)
}
func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool) {
func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool, _ *x509.Certificate) {
// nothing to do
}
@ -176,6 +179,6 @@ func (n *nullPkiCertificateCountManager) StopConsumerJob() {
// nothing to do
}
func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount, storedCount uint64) {
return 0, 0
func (n *nullPkiCertificateCountManager) GetCounts() (issuedCount logical.CertCount) {
return logical.CertCount{}
}

View file

@ -4,6 +4,11 @@
package pki_cert_count
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"sync"
"sync/atomic"
"testing"
@ -14,6 +19,34 @@ import (
"github.com/stretchr/testify/require"
)
// createTestCertificate creates a test certificate with the specified validity duration
func createTestCertificate(t *testing.T, validity time.Duration) *x509.Certificate {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
notBefore := time.Now()
notAfter := notBefore.Add(validity)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "test-cert",
},
NotBefore: notBefore,
NotAfter: notAfter,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)
return cert
}
// TestPkiCertificateCountManager_IncrementAndConsume tests the behaviour of
// PkiCertificateCountManager.
func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) {
@ -41,8 +74,15 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) {
manager.AddCount(logical.CertCount{IssuedCerts: 3, StoredCerts: 0})
manager.AddCount(logical.CertCount{IssuedCerts: 0, StoredCerts: 5})
manager.Increment().AddIssuedCertificate(true)
manager.Increment().AddIssuedCertificate(false)
// Create test certificates with different validity periods
// 730 hours = 1 month = 1.0 billable unit
cert1Month := createTestCertificate(t, 730*time.Hour)
// 8760 hours = 1 year = 12.0 billable units
cert1Year := createTestCertificate(t, 8760*time.Hour)
manager.Increment().AddIssuedCertificate(true, cert1Month)
manager.Increment().AddIssuedCertificate(false, cert1Year)
time.Sleep(100 * time.Millisecond)
@ -54,5 +94,7 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) {
require.Equal(t, uint64(5), jobCount.IssuedCerts, "issued count mismatch")
require.Equal(t, uint64(6), jobCount.StoredCerts, "stored count mismatch")
// cert1Month: 730/730 = 1.0, cert1Year: 8760/730 = 12.0, total = 13.0
require.InDelta(t, 13.0, jobCount.PkiDurationAdjustedCerts, 0.0001, "billable units mismatch")
require.Zero(t, firstConsumerTotalCount.Load(), "first consumer should not have been called")
}

View file

@ -1073,7 +1073,7 @@ type TestClusterCore struct {
UnderlyingHAStorage physical.HABackend
Barrier SecurityBarrier
NodeID string
pkiCertificateCountData struct{ ignoredIssuedCount, ignoredStoredCount uint64 }
pkiCertificateCountData logical.CertCount
}
type PhysicalBackendBundle struct {

View file

@ -29,20 +29,21 @@ func (c *TestClusterCore) StopPkiCertificateCountConsumerJob() {
func (c *TestClusterCore) ResetPkiCertificateCounts() {
mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager)
c.pkiCertificateCountData.ignoredIssuedCount, c.pkiCertificateCountData.ignoredStoredCount = mgr.GetCounts()
c.pkiCertificateCountData = mgr.GetCounts()
}
func (c *TestClusterCore) RequirePkiCertificateCounts(t testing.TB, expectedIssuedCount, expectedStoredCount int) {
t.Helper()
mgr := c.Core.pkiCertCountManager.(pki_cert_count.PkiCertificateCountManager)
actualIssuedCount, actualStoredCount := mgr.GetCounts()
actualCount := mgr.GetCounts()
actualIssuedCount -= c.pkiCertificateCountData.ignoredIssuedCount
actualStoredCount -= c.pkiCertificateCountData.ignoredStoredCount
actualCount.IssuedCerts -= c.pkiCertificateCountData.IssuedCerts
actualCount.StoredCerts -= c.pkiCertificateCountData.StoredCerts
c.pkiCertificateCountData.ignoredIssuedCount += uint64(expectedIssuedCount)
c.pkiCertificateCountData.ignoredStoredCount += uint64(expectedStoredCount)
c.pkiCertificateCountData.IssuedCerts += uint64(expectedIssuedCount)
c.pkiCertificateCountData.StoredCerts += uint64(expectedStoredCount)
require.Equal(t, expectedIssuedCount, int(actualIssuedCount), "PKI certificate issued count mismatch")
require.Equal(t, expectedStoredCount, int(actualStoredCount), "PKI certificate stored count mismatch")
require.Equal(t, expectedIssuedCount, int(actualCount.IssuedCerts), "PKI certificate issued count mismatch")
require.Equal(t, expectedStoredCount, int(actualCount.StoredCerts), "PKI certificate stored count mismatch")
}