Refactor CertificateCounter.IncrementeCount to use a param object. (#12172) (#12271)

* Refactor CertificateCounter.IncrementeCount to use a param object.

In preparation to start collecting more information, refactor the
CertificateCounter to take a parameter object which can be later gain more
fields.

* Rework CertificateCounter to use a fluent interface.

Rename method IncrementCount to AddCount.

Remove method AddIssuedCertificate.

Add method Incrementer, which returns an implementation of the new
CertCountIncrementer.

* Add method CertCountIncrement.Add.

* Refactor PkiCertificateCountConsumer to take a CertCountIncrement.

* Fix TestPkiCertificateCountManager_IncrementAndConsume.

* Rename type CertCountIncrement to CertCount.

* Refactor ReadStoredCounts to return a CertCount value.

Co-authored-by: Victor Rodriguez Rizo <vrizo@hashicorp.com>
This commit is contained in:
Vault Automation 2026-02-10 11:55:10 -05:00 committed by GitHub
parent 521997a16f
commit 7b433e64ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 141 additions and 95 deletions

View file

@ -297,7 +297,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request,
if err != nil {
return nil, err
}
b.pkiCertificateCounter.AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
}
hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber)

View file

@ -484,7 +484,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d
}
}
b.pkiCertificateCounter.AddIssuedCertificate(!role.NoStore)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore)
if useCSR {
if role.UseCSRCommonName && data.Get("common_name").(string) != "" {

View file

@ -308,7 +308,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
if err != nil {
return nil, err
}
b.pkiCertificateCounter.AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
// Build a fresh CRL
warnings, err = b.CrlBuilder().Rebuild(sc, true)
@ -462,7 +462,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R
if err != nil {
return nil, err
}
b.pkiCertificateCounter.AddIssuedCertificate(true)
b.pkiCertificateCounter.Increment().AddIssuedCertificate(true)
if warnAboutTruncate &&
signingBundle.Certificate.NotAfter.Equal(parsedBundle.Certificate.NotAfter) {

View file

@ -508,33 +508,26 @@ func findOpenSSL() (string, string, bool) {
}
type testingPkiCertificateCounter struct {
IssuedCount uint64
StoredCount uint64
count logical.CertCount
}
var _ logical.CertificateCounter = (*testingPkiCertificateCounter)(nil)
func (c *testingPkiCertificateCounter) Reset() {
c.IssuedCount = 0
c.StoredCount = 0
c.count = logical.CertCount{}
}
func (c *testingPkiCertificateCounter) IncrementCount(issuedCerts, storedCerts uint64) {
c.IssuedCount += issuedCerts
c.StoredCount += storedCerts
func (c *testingPkiCertificateCounter) AddCount(params logical.CertCount) {
c.count.Add(params)
}
func (c *testingPkiCertificateCounter) AddIssuedCertificate(stored bool) {
if stored {
c.IncrementCount(1, 1)
} else {
c.IncrementCount(1, 0)
}
func (c *testingPkiCertificateCounter) Increment() logical.CertCountIncrementer {
return logical.NewCertCountIncrementer(c)
}
func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64) {
require.Equal(t, issuedCerts, c.IssuedCount, "issued certificates count mismatch %s")
require.Equal(t, storedCerts, c.StoredCount, "stored certificates count mismatch %s")
require.Equal(t, issuedCerts, c.count.IssuedCerts, "issued certificates count mismatch %s")
require.Equal(t, storedCerts, c.count.StoredCerts, "stored certificates count mismatch %s")
}
func (c *testingPkiCertificateCounter) RequireZero(t require.TestingT) {

View file

@ -7,9 +7,51 @@ package logical
// certificates.
type CertificateCounter interface {
// IncrementCount increments the count of issued and stored certificates.
IncrementCount(issuedCerts, storedCerts uint64)
AddCount(params CertCount)
// AddIssuedCertificate increments the issued certificate count by 1, and also the
// stored certificate count if stored is true.
AddIssuedCertificate(stored bool)
// Increment returns a CertCountIncrementer that can be used to add
// to the count.
Increment() CertCountIncrementer
}
// CertCount represents the parameters for incrementing certificate counts.
type CertCount struct {
IssuedCerts uint64
StoredCerts uint64
}
func (i *CertCount) Add(other CertCount) {
i.IssuedCerts += other.IssuedCerts
i.StoredCerts += other.StoredCerts
}
func (i *CertCount) IsZero() bool {
return i.IssuedCerts == 0 && i.StoredCerts == 0
}
type CertCountIncrementer interface {
AddIssuedCertificate(stored bool) CertCountIncrementer
}
type certCountIncrementer struct {
counter CertificateCounter
}
var _ CertCountIncrementer = (*certCountIncrementer)(nil)
// NewCertCountIncrementer creates a new CertCountIncrementer for the given counter.
func NewCertCountIncrementer(counter CertificateCounter) CertCountIncrementer {
return &certCountIncrementer{counter: counter}
}
// AddIssuedCertificate increments the issued certificate count by 1, and also the
// stored certificate count if stored is true.
func (c *certCountIncrementer) AddIssuedCertificate(stored bool) CertCountIncrementer {
count := CertCount{IssuedCerts: 1}
if stored {
count.StoredCerts = 1
}
c.counter.AddCount(count)
return c
}

View file

@ -9,10 +9,11 @@ type PkiCertificateCountSystemView interface {
type nullPkiCertificateCounter struct{}
func (n *nullPkiCertificateCounter) IncrementCount(_, _ uint64) {
func (n *nullPkiCertificateCounter) AddCount(_ CertCount) {
}
func (n *nullPkiCertificateCounter) AddIssuedCertificate(_ bool) {
func (n *nullPkiCertificateCounter) Increment() CertCountIncrementer {
return NewCertCountIncrementer(n)
}
var _ CertificateCounter = (*nullPkiCertificateCounter)(nil)

View file

@ -291,8 +291,8 @@ func ceSysInitialize(b *SystemBackend) func(context.Context, *logical.Initializa
return fmt.Errorf("failed to initialize activation flags: %w", err)
}
b.Core.pkiCertCountManager.StartConsumerJob(func(issuedCount, storedCount uint64) {
b.Core.consumePkiCertCounts(issuedCount, storedCount)
b.Core.pkiCertCountManager.StartConsumerJob(func(increment logical.CertCount) {
b.Core.consumePkiCertCounts(increment)
})
return nil
}
@ -300,10 +300,10 @@ func ceSysInitialize(b *SystemBackend) func(context.Context, *logical.Initializa
// consumePkiCertCounts updates the PKI certificate counts in storage if we are
// running on the active node; otherwise it forwards them to the active node.
func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) {
func (c *Core) consumePkiCertCounts(inc logical.CertCount) {
var consumed bool
haState := c.HAStateWithLock()
if issuedCount == 0 && storedCount == 0 {
if inc.IsZero() {
return
}
@ -311,10 +311,10 @@ func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) {
case consts.Standby:
consumed = true
case consts.PerfStandby:
consumed = forwardPkiCertCounts(c, issuedCount, storedCount)
consumed = forwardPkiCertCounts(c, inc.IssuedCerts, inc.StoredCerts)
case consts.Active:
c.logger.Info("storing PKI certificate counts", "issuedCerts", issuedCount, "storedCerts", storedCount)
err := pki_cert_count.IncrementStoredCounts(c.activeContext, c.barrier, issuedCount, storedCount)
c.logger.Info("storing PKI certificate counts", "issuedCerts", inc.IssuedCerts, "storedCerts", inc.StoredCerts)
err := pki_cert_count.IncrementStoredCounts(c.activeContext, c.barrier, inc)
if err != nil {
c.logger.Error("error storing PKI certificate counts", "error", err)
} else {
@ -324,7 +324,7 @@ func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) {
c.logger.Error("Unexpected HA state when consuming PKI certificate counts", "ha_state", haState)
}
if !consumed {
c.pkiCertCountManager.IncrementCount(issuedCount, storedCount)
c.pkiCertCountManager.AddCount(inc)
}
}

View file

@ -6,7 +6,6 @@ package pki_cert_count
import (
"os"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-hclog"
@ -22,7 +21,7 @@ const envVaultDisableCertCount = "VAULT_DISABLE_CERT_COUNT"
var consumerJobInterval = 1 * time.Minute
// PkiCertificateCountConsumer is a callback for consumers of the PKI certificate counts.
type PkiCertificateCountConsumer func(issuedCount, storedCount uint64)
type PkiCertificateCountConsumer func(logical.CertCount)
// PkiCertificateCountManager keeps track of issued and stored PKI certificate counts.
type PkiCertificateCountManager interface {
@ -42,8 +41,8 @@ type PkiCertificateCountManager interface {
// certCountManager is an implementation of PkiCertificateCountManager.
type certCountManager struct {
issuedCount *atomic.Uint64
storedCount *atomic.Uint64
count logical.CertCount
countLock sync.RWMutex
reportTimerStop chan struct{}
reportTimerStopLock sync.Mutex
@ -66,8 +65,7 @@ func InitPkiCertificateCountManager(logger hclog.Logger) PkiCertificateCountMana
func newPkiCertificateCountManager(logger hclog.Logger) PkiCertificateCountManager {
ret := &certCountManager{
issuedCount: &atomic.Uint64{},
storedCount: &atomic.Uint64{},
count: logical.CertCount{},
reportTimerStop: nil,
logger: logger,
}
@ -102,9 +100,13 @@ func (m *certCountManager) reportLoop(stop chan struct{}, consumer PkiCertificat
}
func (m *certCountManager) consumeCount(consumer PkiCertificateCountConsumer) {
issuedCount := m.issuedCount.Swap(0)
storedCount := m.storedCount.Swap(0)
consumer(issuedCount, storedCount)
m.countLock.Lock()
defer m.countLock.Unlock()
increment := m.count
m.count = logical.CertCount{}
consumer(increment)
}
func (m *certCountManager) StopConsumerJob() {
@ -124,22 +126,23 @@ func (m *certCountManager) stopConsumerJobWithLock() {
}
}
func (m *certCountManager) AddIssuedCertificate(stored bool) {
if stored {
m.IncrementCount(1, 1)
} else {
m.IncrementCount(1, 0)
}
func (m *certCountManager) AddCount(params logical.CertCount) {
m.countLock.Lock()
defer m.countLock.Unlock()
m.count.Add(params)
m.logger.Trace("incremented in-memory PKI certificate counts", "issuedCerts", m.count.IssuedCerts, "storedCerts", m.count.StoredCerts)
}
func (m *certCountManager) IncrementCount(issuedCerts, storedCerts uint64) {
issued := m.issuedCount.Add(issuedCerts)
stored := m.storedCount.Add(storedCerts)
m.logger.Trace("incremented in-memory PKI certificate counts", "issuedCerts", issued, "storedCerts", stored)
func (m *certCountManager) Increment() logical.CertCountIncrementer {
return logical.NewCertCountIncrementer(m)
}
func (m *certCountManager) GetCounts() (issuedCount, storedCount uint64) {
return m.issuedCount.Load(), m.storedCount.Load()
m.countLock.RLock()
defer m.countLock.RUnlock()
return m.count.IssuedCerts, m.count.StoredCerts
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -153,10 +156,14 @@ func newNullPkiCertificateCountManager() PkiCertificateCountManager {
return &nullPkiCertificateCountManager{}
}
func (n *nullPkiCertificateCountManager) IncrementCount(_, _ uint64) {
func (n *nullPkiCertificateCountManager) AddCount(_ logical.CertCount) {
// nothing to do
}
func (n *nullPkiCertificateCountManager) Increment() logical.CertCountIncrementer {
return logical.NewCertCountIncrementer(n)
}
func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool) {
// nothing to do
}

View file

@ -4,11 +4,13 @@
package pki_cert_count
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
@ -19,16 +21,16 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) {
consumerJobInterval = 10 * time.Millisecond
firstConsumerTotalCount := &atomic.Uint64{}
manager.StartConsumerJob(func(i, s uint64) {
firstConsumerTotalCount.Add(i + s)
manager.StartConsumerJob(func(inc logical.CertCount) {
firstConsumerTotalCount.Add(inc.IssuedCerts + inc.StoredCerts)
})
issued := &atomic.Uint64{}
stored := &atomic.Uint64{}
consumer := func(i, s uint64) {
issued.Add(i)
stored.Add(s)
jobCountLock := sync.Mutex{}
jobCount := logical.CertCount{}
consumer := func(inc logical.CertCount) {
jobCountLock.Lock()
defer jobCountLock.Unlock()
jobCount.Add(inc)
}
manager.StartConsumerJob(consumer)
@ -37,17 +39,17 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) {
time.Sleep(20 * time.Millisecond)
firstConsumerTotalCount.Store(0)
manager.IncrementCount(3, 0)
manager.IncrementCount(0, 5)
manager.AddIssuedCertificate(true)
manager.AddIssuedCertificate(false)
manager.AddCount(logical.CertCount{IssuedCerts: 3, StoredCerts: 0})
manager.AddCount(logical.CertCount{IssuedCerts: 0, StoredCerts: 5})
manager.Increment().AddIssuedCertificate(true)
manager.Increment().AddIssuedCertificate(false)
time.Sleep(100 * time.Millisecond)
// Calling stop again should not panic.
manager.StopConsumerJob()
require.Equal(t, uint64(5), issued.Load(), "issued count mismatch")
require.Equal(t, uint64(6), stored.Load(), "stored count mismatch")
require.Equal(t, uint64(5), jobCount.IssuedCerts, "issued count mismatch")
require.Equal(t, uint64(6), jobCount.StoredCerts, "stored count mismatch")
require.Zero(t, firstConsumerTotalCount.Load(), "first consumer should not have been called")
}

View file

@ -25,7 +25,7 @@ type PkiCertificateCount struct {
StoredCertificateCountsByDay []uint64 `json:"storedCertificateCountsByDay"`
}
func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedCount, storedCount uint64) error {
func IncrementStoredCounts(ctx context.Context, storage logical.Storage, inc logical.CertCount) error {
year, month, day := time.Now().Date()
storagePath := getStoragePath(year, month)
@ -47,8 +47,8 @@ func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedC
currentMonthCounts.StoredCertificateCountsByDay = make([]uint64, daysInMonth+1)
}
currentMonthCounts.IssuedCertificateCountsByDay[day] += issuedCount
currentMonthCounts.StoredCertificateCountsByDay[day] += storedCount
currentMonthCounts.IssuedCertificateCountsByDay[day] += inc.IssuedCerts
currentMonthCounts.StoredCertificateCountsByDay[day] += inc.StoredCerts
countBytes, err := json.Marshal(currentMonthCounts)
if err != nil {
@ -68,33 +68,31 @@ func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedC
return nil
}
func ReadStoredCounts(ctx context.Context, storage logical.Storage, date time.Time) (issuedCount uint64, storedCount uint64, err error) {
issuedCount = 0
storedCount = 0
func ReadStoredCounts(ctx context.Context, storage logical.Storage, date time.Time) (count logical.CertCount, err error) {
year, month, _ := date.Date()
entry, err := storage.Get(ctx, getStoragePath(year, month))
if err != nil {
return 0, 0, fmt.Errorf("error reading from storage: %w", err)
return logical.CertCount{}, fmt.Errorf("error reading from storage: %w", err)
}
if entry == nil {
return 0, 0, fmt.Errorf("certificate counts not found for %d-%02d", year, month)
return logical.CertCount{}, fmt.Errorf("certificate counts not found for %d-%02d", year, month)
}
var certificateCounts PkiCertificateCount
err = json.Unmarshal(entry.Value, &certificateCounts)
if err != nil {
return 0, 0, fmt.Errorf("error unmarshalling certificate counts from storage: %w", err)
return logical.CertCount{}, fmt.Errorf("error unmarshalling certificate counts from storage: %w", err)
}
ret := logical.CertCount{}
for i := range certificateCounts.IssuedCertificateCountsByDay {
issuedCount += certificateCounts.IssuedCertificateCountsByDay[i]
storedCount += certificateCounts.StoredCertificateCountsByDay[i]
ret.IssuedCerts += certificateCounts.IssuedCertificateCountsByDay[i]
ret.StoredCerts += certificateCounts.StoredCertificateCountsByDay[i]
}
return issuedCount, storedCount, nil
return ret, nil
}
func getStoragePath(year int, month time.Month) string {

View file

@ -53,13 +53,13 @@ func TestGetCertificateCount(t *testing.T) {
for name, tt := range testCases {
t.Run(name, func(t *testing.T) {
issuedCount, storedCount, err := ReadStoredCounts(context.Background(), storage, tt.date)
count, err := ReadStoredCounts(context.Background(), storage, tt.date)
if tt.expectErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedIssued, issuedCount)
require.Equal(t, tt.expectedStored, storedCount)
require.Equal(t, tt.expectedIssued, count.IssuedCerts)
require.Equal(t, tt.expectedStored, count.StoredCerts)
}
})
}
@ -71,17 +71,19 @@ func TestStoreCertificateCounts(t *testing.T) {
storage := logical.NewLogicalStorage(backend)
var expectedIssuedCount uint64 = 5
var expectedStoredCount uint64 = 3
expectedCount := logical.CertCount{
IssuedCerts: 5,
StoredCerts: 3,
}
err = IncrementStoredCounts(context.Background(), storage, expectedIssuedCount, expectedStoredCount)
err = IncrementStoredCounts(context.Background(), storage, expectedCount)
require.NoError(t, err)
year, month, day := time.Now().Date()
counts := retrieveCertificateCountsFromStorage(t, storage, year, month)
require.Equal(t, expectedIssuedCount, counts.IssuedCertificateCountsByDay[day])
require.Equal(t, expectedStoredCount, counts.StoredCertificateCountsByDay[day])
require.Equal(t, expectedCount.IssuedCerts, counts.IssuedCertificateCountsByDay[day])
require.Equal(t, expectedCount.StoredCerts, counts.StoredCertificateCountsByDay[day])
}
func TestReadAfterStore(t *testing.T) {
@ -90,16 +92,17 @@ func TestReadAfterStore(t *testing.T) {
storage := logical.NewLogicalStorage(backend)
var expectedIssuedCount uint64 = 5
var expectedStoredCount uint64 = 3
expectedCount := logical.CertCount{
IssuedCerts: 5,
StoredCerts: 3,
}
err = IncrementStoredCounts(context.Background(), storage, expectedIssuedCount, expectedStoredCount)
err = IncrementStoredCounts(context.Background(), storage, expectedCount)
require.NoError(t, err)
issued, stored, err := ReadStoredCounts(context.Background(), storage, time.Now())
actual, err := ReadStoredCounts(context.Background(), storage, time.Now())
require.NoError(t, err)
require.Equal(t, expectedIssuedCount, issued)
require.Equal(t, expectedStoredCount, stored)
require.Equal(t, expectedCount, actual)
}
func retrieveCertificateCountsFromStorage(t *testing.T, storage logical.Storage, year int, month time.Month) *PkiCertificateCount {