diff --git a/builtin/logical/pki/path_acme_order.go b/builtin/logical/pki/path_acme_order.go index 53c8912442..089589c81f 100644 --- a/builtin/logical/pki/path_acme_order.go +++ b/builtin/logical/pki/path_acme_order.go @@ -297,7 +297,7 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, r *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) } hyphenSerialNumber := normalizeSerialFromBigInt(signedCertBundle.Certificate.SerialNumber) diff --git a/builtin/logical/pki/path_issue_sign.go b/builtin/logical/pki/path_issue_sign.go index 5fe6d7bc0a..7d1e61e494 100644 --- a/builtin/logical/pki/path_issue_sign.go +++ b/builtin/logical/pki/path_issue_sign.go @@ -484,7 +484,7 @@ func (b *backend) pathIssueSignCert(ctx context.Context, req *logical.Request, d } } - b.pkiCertificateCounter.AddIssuedCertificate(!role.NoStore) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(!role.NoStore) if useCSR { if role.UseCSRCommonName && data.Get("common_name").(string) != "" { diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index 173ca059b0..2b4940374e 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -308,7 +308,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, if err != nil { return nil, err } - b.pkiCertificateCounter.AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) // Build a fresh CRL warnings, err = b.CrlBuilder().Rebuild(sc, true) @@ -462,7 +462,7 @@ func (b *backend) pathIssuerSignIntermediate(ctx context.Context, req *logical.R if err != nil { return nil, err } - b.pkiCertificateCounter.AddIssuedCertificate(true) + b.pkiCertificateCounter.Increment().AddIssuedCertificate(true) if warnAboutTruncate && signingBundle.Certificate.NotAfter.Equal(parsedBundle.Certificate.NotAfter) { diff --git a/builtin/logical/pki/test_helpers.go b/builtin/logical/pki/test_helpers.go index 672f5d3ccb..781f51a048 100644 --- a/builtin/logical/pki/test_helpers.go +++ b/builtin/logical/pki/test_helpers.go @@ -508,33 +508,26 @@ func findOpenSSL() (string, string, bool) { } type testingPkiCertificateCounter struct { - IssuedCount uint64 - StoredCount uint64 + count logical.CertCount } var _ logical.CertificateCounter = (*testingPkiCertificateCounter)(nil) func (c *testingPkiCertificateCounter) Reset() { - c.IssuedCount = 0 - c.StoredCount = 0 + c.count = logical.CertCount{} } -func (c *testingPkiCertificateCounter) IncrementCount(issuedCerts, storedCerts uint64) { - c.IssuedCount += issuedCerts - c.StoredCount += storedCerts +func (c *testingPkiCertificateCounter) AddCount(params logical.CertCount) { + c.count.Add(params) } -func (c *testingPkiCertificateCounter) AddIssuedCertificate(stored bool) { - if stored { - c.IncrementCount(1, 1) - } else { - c.IncrementCount(1, 0) - } +func (c *testingPkiCertificateCounter) Increment() logical.CertCountIncrementer { + return logical.NewCertCountIncrementer(c) } func (c *testingPkiCertificateCounter) RequireCount(t require.TestingT, issuedCerts, storedCerts uint64) { - require.Equal(t, issuedCerts, c.IssuedCount, "issued certificates count mismatch %s") - require.Equal(t, storedCerts, c.StoredCount, "stored certificates count mismatch %s") + require.Equal(t, issuedCerts, c.count.IssuedCerts, "issued certificates count mismatch %s") + require.Equal(t, storedCerts, c.count.StoredCerts, "stored certificates count mismatch %s") } func (c *testingPkiCertificateCounter) RequireZero(t require.TestingT) { diff --git a/sdk/logical/certificate_counter.go b/sdk/logical/certificate_counter.go index dbf49e46b4..4633a40ce6 100644 --- a/sdk/logical/certificate_counter.go +++ b/sdk/logical/certificate_counter.go @@ -7,9 +7,51 @@ package logical // certificates. type CertificateCounter interface { // IncrementCount increments the count of issued and stored certificates. - IncrementCount(issuedCerts, storedCerts uint64) + AddCount(params CertCount) - // AddIssuedCertificate increments the issued certificate count by 1, and also the - // stored certificate count if stored is true. - AddIssuedCertificate(stored bool) + // Increment returns a CertCountIncrementer that can be used to add + // to the count. + Increment() CertCountIncrementer +} + +// CertCount represents the parameters for incrementing certificate counts. +type CertCount struct { + IssuedCerts uint64 + StoredCerts uint64 +} + +func (i *CertCount) Add(other CertCount) { + i.IssuedCerts += other.IssuedCerts + i.StoredCerts += other.StoredCerts +} + +func (i *CertCount) IsZero() bool { + return i.IssuedCerts == 0 && i.StoredCerts == 0 +} + +type CertCountIncrementer interface { + AddIssuedCertificate(stored bool) CertCountIncrementer +} + +type certCountIncrementer struct { + counter CertificateCounter +} + +var _ CertCountIncrementer = (*certCountIncrementer)(nil) + +// NewCertCountIncrementer creates a new CertCountIncrementer for the given counter. +func NewCertCountIncrementer(counter CertificateCounter) CertCountIncrementer { + return &certCountIncrementer{counter: counter} +} + +// AddIssuedCertificate increments the issued certificate count by 1, and also the +// stored certificate count if stored is true. +func (c *certCountIncrementer) AddIssuedCertificate(stored bool) CertCountIncrementer { + count := CertCount{IssuedCerts: 1} + if stored { + count.StoredCerts = 1 + } + c.counter.AddCount(count) + + return c } diff --git a/sdk/logical/pki_cert_count_system_view.go b/sdk/logical/pki_cert_count_system_view.go index 41a9aac9f8..429a1329bc 100644 --- a/sdk/logical/pki_cert_count_system_view.go +++ b/sdk/logical/pki_cert_count_system_view.go @@ -9,10 +9,11 @@ type PkiCertificateCountSystemView interface { type nullPkiCertificateCounter struct{} -func (n *nullPkiCertificateCounter) IncrementCount(_, _ uint64) { +func (n *nullPkiCertificateCounter) AddCount(_ CertCount) { } -func (n *nullPkiCertificateCounter) AddIssuedCertificate(_ bool) { +func (n *nullPkiCertificateCounter) Increment() CertCountIncrementer { + return NewCertCountIncrementer(n) } var _ CertificateCounter = (*nullPkiCertificateCounter)(nil) diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index f663564fff..946979add2 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -291,8 +291,8 @@ func ceSysInitialize(b *SystemBackend) func(context.Context, *logical.Initializa return fmt.Errorf("failed to initialize activation flags: %w", err) } - b.Core.pkiCertCountManager.StartConsumerJob(func(issuedCount, storedCount uint64) { - b.Core.consumePkiCertCounts(issuedCount, storedCount) + b.Core.pkiCertCountManager.StartConsumerJob(func(increment logical.CertCount) { + b.Core.consumePkiCertCounts(increment) }) return nil } @@ -300,10 +300,10 @@ func ceSysInitialize(b *SystemBackend) func(context.Context, *logical.Initializa // consumePkiCertCounts updates the PKI certificate counts in storage if we are // running on the active node; otherwise it forwards them to the active node. -func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) { +func (c *Core) consumePkiCertCounts(inc logical.CertCount) { var consumed bool haState := c.HAStateWithLock() - if issuedCount == 0 && storedCount == 0 { + if inc.IsZero() { return } @@ -311,10 +311,10 @@ func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) { case consts.Standby: consumed = true case consts.PerfStandby: - consumed = forwardPkiCertCounts(c, issuedCount, storedCount) + consumed = forwardPkiCertCounts(c, inc.IssuedCerts, inc.StoredCerts) case consts.Active: - c.logger.Info("storing PKI certificate counts", "issuedCerts", issuedCount, "storedCerts", storedCount) - err := pki_cert_count.IncrementStoredCounts(c.activeContext, c.barrier, issuedCount, storedCount) + c.logger.Info("storing PKI certificate counts", "issuedCerts", inc.IssuedCerts, "storedCerts", inc.StoredCerts) + err := pki_cert_count.IncrementStoredCounts(c.activeContext, c.barrier, inc) if err != nil { c.logger.Error("error storing PKI certificate counts", "error", err) } else { @@ -324,7 +324,7 @@ func (c *Core) consumePkiCertCounts(issuedCount uint64, storedCount uint64) { c.logger.Error("Unexpected HA state when consuming PKI certificate counts", "ha_state", haState) } if !consumed { - c.pkiCertCountManager.IncrementCount(issuedCount, storedCount) + c.pkiCertCountManager.AddCount(inc) } } diff --git a/vault/pki_cert_count/pki_cert_count_manager.go b/vault/pki_cert_count/pki_cert_count_manager.go index bcd48d39a3..503ca60fec 100644 --- a/vault/pki_cert_count/pki_cert_count_manager.go +++ b/vault/pki_cert_count/pki_cert_count_manager.go @@ -6,7 +6,6 @@ package pki_cert_count import ( "os" "sync" - "sync/atomic" "time" "github.com/hashicorp/go-hclog" @@ -22,7 +21,7 @@ const envVaultDisableCertCount = "VAULT_DISABLE_CERT_COUNT" var consumerJobInterval = 1 * time.Minute // PkiCertificateCountConsumer is a callback for consumers of the PKI certificate counts. -type PkiCertificateCountConsumer func(issuedCount, storedCount uint64) +type PkiCertificateCountConsumer func(logical.CertCount) // PkiCertificateCountManager keeps track of issued and stored PKI certificate counts. type PkiCertificateCountManager interface { @@ -42,8 +41,8 @@ type PkiCertificateCountManager interface { // certCountManager is an implementation of PkiCertificateCountManager. type certCountManager struct { - issuedCount *atomic.Uint64 - storedCount *atomic.Uint64 + count logical.CertCount + countLock sync.RWMutex reportTimerStop chan struct{} reportTimerStopLock sync.Mutex @@ -66,8 +65,7 @@ func InitPkiCertificateCountManager(logger hclog.Logger) PkiCertificateCountMana func newPkiCertificateCountManager(logger hclog.Logger) PkiCertificateCountManager { ret := &certCountManager{ - issuedCount: &atomic.Uint64{}, - storedCount: &atomic.Uint64{}, + count: logical.CertCount{}, reportTimerStop: nil, logger: logger, } @@ -102,9 +100,13 @@ func (m *certCountManager) reportLoop(stop chan struct{}, consumer PkiCertificat } func (m *certCountManager) consumeCount(consumer PkiCertificateCountConsumer) { - issuedCount := m.issuedCount.Swap(0) - storedCount := m.storedCount.Swap(0) - consumer(issuedCount, storedCount) + m.countLock.Lock() + defer m.countLock.Unlock() + + increment := m.count + m.count = logical.CertCount{} + + consumer(increment) } func (m *certCountManager) StopConsumerJob() { @@ -124,22 +126,23 @@ func (m *certCountManager) stopConsumerJobWithLock() { } } -func (m *certCountManager) AddIssuedCertificate(stored bool) { - if stored { - m.IncrementCount(1, 1) - } else { - m.IncrementCount(1, 0) - } +func (m *certCountManager) AddCount(params logical.CertCount) { + m.countLock.Lock() + defer m.countLock.Unlock() + + m.count.Add(params) + + m.logger.Trace("incremented in-memory PKI certificate counts", "issuedCerts", m.count.IssuedCerts, "storedCerts", m.count.StoredCerts) } -func (m *certCountManager) IncrementCount(issuedCerts, storedCerts uint64) { - issued := m.issuedCount.Add(issuedCerts) - stored := m.storedCount.Add(storedCerts) - m.logger.Trace("incremented in-memory PKI certificate counts", "issuedCerts", issued, "storedCerts", stored) +func (m *certCountManager) Increment() logical.CertCountIncrementer { + return logical.NewCertCountIncrementer(m) } func (m *certCountManager) GetCounts() (issuedCount, storedCount uint64) { - return m.issuedCount.Load(), m.storedCount.Load() + m.countLock.RLock() + defer m.countLock.RUnlock() + return m.count.IssuedCerts, m.count.StoredCerts } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -153,10 +156,14 @@ func newNullPkiCertificateCountManager() PkiCertificateCountManager { return &nullPkiCertificateCountManager{} } -func (n *nullPkiCertificateCountManager) IncrementCount(_, _ uint64) { +func (n *nullPkiCertificateCountManager) AddCount(_ logical.CertCount) { // nothing to do } +func (n *nullPkiCertificateCountManager) Increment() logical.CertCountIncrementer { + return logical.NewCertCountIncrementer(n) +} + func (n *nullPkiCertificateCountManager) AddIssuedCertificate(_ bool) { // nothing to do } diff --git a/vault/pki_cert_count/pki_cert_count_manager_test.go b/vault/pki_cert_count/pki_cert_count_manager_test.go index 37a69031ca..ef469976ea 100644 --- a/vault/pki_cert_count/pki_cert_count_manager_test.go +++ b/vault/pki_cert_count/pki_cert_count_manager_test.go @@ -4,11 +4,13 @@ package pki_cert_count import ( + "sync" "sync/atomic" "testing" "time" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" "github.com/stretchr/testify/require" ) @@ -19,16 +21,16 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { consumerJobInterval = 10 * time.Millisecond firstConsumerTotalCount := &atomic.Uint64{} - manager.StartConsumerJob(func(i, s uint64) { - firstConsumerTotalCount.Add(i + s) + manager.StartConsumerJob(func(inc logical.CertCount) { + firstConsumerTotalCount.Add(inc.IssuedCerts + inc.StoredCerts) }) - issued := &atomic.Uint64{} - stored := &atomic.Uint64{} - - consumer := func(i, s uint64) { - issued.Add(i) - stored.Add(s) + jobCountLock := sync.Mutex{} + jobCount := logical.CertCount{} + consumer := func(inc logical.CertCount) { + jobCountLock.Lock() + defer jobCountLock.Unlock() + jobCount.Add(inc) } manager.StartConsumerJob(consumer) @@ -37,17 +39,17 @@ func TestPkiCertificateCountManager_IncrementAndConsume(t *testing.T) { time.Sleep(20 * time.Millisecond) firstConsumerTotalCount.Store(0) - manager.IncrementCount(3, 0) - manager.IncrementCount(0, 5) - manager.AddIssuedCertificate(true) - manager.AddIssuedCertificate(false) + manager.AddCount(logical.CertCount{IssuedCerts: 3, StoredCerts: 0}) + manager.AddCount(logical.CertCount{IssuedCerts: 0, StoredCerts: 5}) + manager.Increment().AddIssuedCertificate(true) + manager.Increment().AddIssuedCertificate(false) time.Sleep(100 * time.Millisecond) // Calling stop again should not panic. manager.StopConsumerJob() - require.Equal(t, uint64(5), issued.Load(), "issued count mismatch") - require.Equal(t, uint64(6), stored.Load(), "stored count mismatch") + require.Equal(t, uint64(5), jobCount.IssuedCerts, "issued count mismatch") + require.Equal(t, uint64(6), jobCount.StoredCerts, "stored count mismatch") require.Zero(t, firstConsumerTotalCount.Load(), "first consumer should not have been called") } diff --git a/vault/pki_cert_count/pki_cert_count_storage.go b/vault/pki_cert_count/pki_cert_count_storage.go index 7ce3dcc155..db367888a6 100644 --- a/vault/pki_cert_count/pki_cert_count_storage.go +++ b/vault/pki_cert_count/pki_cert_count_storage.go @@ -25,7 +25,7 @@ type PkiCertificateCount struct { StoredCertificateCountsByDay []uint64 `json:"storedCertificateCountsByDay"` } -func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedCount, storedCount uint64) error { +func IncrementStoredCounts(ctx context.Context, storage logical.Storage, inc logical.CertCount) error { year, month, day := time.Now().Date() storagePath := getStoragePath(year, month) @@ -47,8 +47,8 @@ func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedC currentMonthCounts.StoredCertificateCountsByDay = make([]uint64, daysInMonth+1) } - currentMonthCounts.IssuedCertificateCountsByDay[day] += issuedCount - currentMonthCounts.StoredCertificateCountsByDay[day] += storedCount + currentMonthCounts.IssuedCertificateCountsByDay[day] += inc.IssuedCerts + currentMonthCounts.StoredCertificateCountsByDay[day] += inc.StoredCerts countBytes, err := json.Marshal(currentMonthCounts) if err != nil { @@ -68,33 +68,31 @@ func IncrementStoredCounts(ctx context.Context, storage logical.Storage, issuedC return nil } -func ReadStoredCounts(ctx context.Context, storage logical.Storage, date time.Time) (issuedCount uint64, storedCount uint64, err error) { - issuedCount = 0 - storedCount = 0 - +func ReadStoredCounts(ctx context.Context, storage logical.Storage, date time.Time) (count logical.CertCount, err error) { year, month, _ := date.Date() entry, err := storage.Get(ctx, getStoragePath(year, month)) if err != nil { - return 0, 0, fmt.Errorf("error reading from storage: %w", err) + return logical.CertCount{}, fmt.Errorf("error reading from storage: %w", err) } if entry == nil { - return 0, 0, fmt.Errorf("certificate counts not found for %d-%02d", year, month) + return logical.CertCount{}, fmt.Errorf("certificate counts not found for %d-%02d", year, month) } var certificateCounts PkiCertificateCount err = json.Unmarshal(entry.Value, &certificateCounts) if err != nil { - return 0, 0, fmt.Errorf("error unmarshalling certificate counts from storage: %w", err) + return logical.CertCount{}, fmt.Errorf("error unmarshalling certificate counts from storage: %w", err) } + ret := logical.CertCount{} for i := range certificateCounts.IssuedCertificateCountsByDay { - issuedCount += certificateCounts.IssuedCertificateCountsByDay[i] - storedCount += certificateCounts.StoredCertificateCountsByDay[i] + ret.IssuedCerts += certificateCounts.IssuedCertificateCountsByDay[i] + ret.StoredCerts += certificateCounts.StoredCertificateCountsByDay[i] } - return issuedCount, storedCount, nil + return ret, nil } func getStoragePath(year int, month time.Month) string { diff --git a/vault/pki_cert_count/pki_cert_count_storage_test.go b/vault/pki_cert_count/pki_cert_count_storage_test.go index 1935628324..8a5770ef01 100644 --- a/vault/pki_cert_count/pki_cert_count_storage_test.go +++ b/vault/pki_cert_count/pki_cert_count_storage_test.go @@ -53,13 +53,13 @@ func TestGetCertificateCount(t *testing.T) { for name, tt := range testCases { t.Run(name, func(t *testing.T) { - issuedCount, storedCount, err := ReadStoredCounts(context.Background(), storage, tt.date) + count, err := ReadStoredCounts(context.Background(), storage, tt.date) if tt.expectErr { require.Error(t, err) } else { require.NoError(t, err) - require.Equal(t, tt.expectedIssued, issuedCount) - require.Equal(t, tt.expectedStored, storedCount) + require.Equal(t, tt.expectedIssued, count.IssuedCerts) + require.Equal(t, tt.expectedStored, count.StoredCerts) } }) } @@ -71,17 +71,19 @@ func TestStoreCertificateCounts(t *testing.T) { storage := logical.NewLogicalStorage(backend) - var expectedIssuedCount uint64 = 5 - var expectedStoredCount uint64 = 3 + expectedCount := logical.CertCount{ + IssuedCerts: 5, + StoredCerts: 3, + } - err = IncrementStoredCounts(context.Background(), storage, expectedIssuedCount, expectedStoredCount) + err = IncrementStoredCounts(context.Background(), storage, expectedCount) require.NoError(t, err) year, month, day := time.Now().Date() counts := retrieveCertificateCountsFromStorage(t, storage, year, month) - require.Equal(t, expectedIssuedCount, counts.IssuedCertificateCountsByDay[day]) - require.Equal(t, expectedStoredCount, counts.StoredCertificateCountsByDay[day]) + require.Equal(t, expectedCount.IssuedCerts, counts.IssuedCertificateCountsByDay[day]) + require.Equal(t, expectedCount.StoredCerts, counts.StoredCertificateCountsByDay[day]) } func TestReadAfterStore(t *testing.T) { @@ -90,16 +92,17 @@ func TestReadAfterStore(t *testing.T) { storage := logical.NewLogicalStorage(backend) - var expectedIssuedCount uint64 = 5 - var expectedStoredCount uint64 = 3 + expectedCount := logical.CertCount{ + IssuedCerts: 5, + StoredCerts: 3, + } - err = IncrementStoredCounts(context.Background(), storage, expectedIssuedCount, expectedStoredCount) + err = IncrementStoredCounts(context.Background(), storage, expectedCount) require.NoError(t, err) - issued, stored, err := ReadStoredCounts(context.Background(), storage, time.Now()) + actual, err := ReadStoredCounts(context.Background(), storage, time.Now()) require.NoError(t, err) - require.Equal(t, expectedIssuedCount, issued) - require.Equal(t, expectedStoredCount, stored) + require.Equal(t, expectedCount, actual) } func retrieveCertificateCountsFromStorage(t *testing.T, storage logical.Storage, year int, month time.Month) *PkiCertificateCount {