Add counting for SSH certs and OTPs (#12368) (#12755)

* add cert counting for ssh

* add system view and fix errors

* add otp counting and change units for certs

* add storage tests

* fix census errors

* run make fmt

* use incrementer and change storage to match rfc

* run make fmt

* fix interface and remove parameter

* fix errors

* Update builtin/logical/ssh/path_creds_create.go



* remove error check

* add ssh counts to billing endpoint

* fix error

* add test case

* add ssh metric to test

* add get functions and tests

* fix format

* create function for ssh metrics

* refactoring and add test cases

* replace test check

* add ssh to billing overview test

---------

Co-authored-by: Rachel Culpepper <84159930+rculpepper@users.noreply.github.com>
Co-authored-by: Victor Rodriguez Rizo <vrizo@hashicorp.com>
This commit is contained in:
Vault Automation 2026-03-11 10:30:48 -04:00 committed by GitHub
parent 921dc42cdc
commit d34cb72e68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 630 additions and 7 deletions

View file

@ -33,7 +33,7 @@ func TestSys_BillingOverview(t *testing.T) {
currentMonth := resp.Months[0]
require.Equal(t, "2026-01", currentMonth.Month)
require.Equal(t, "2026-01-14T10:49:00Z", currentMonth.UpdatedAt)
require.Len(t, currentMonth.UsageMetrics, 8, "should have all 8 metrics")
require.Len(t, currentMonth.UsageMetrics, 9, "should have all 9 metrics")
// Create a map to verify all expected metrics are present
metricsMap := make(map[string]UsageMetric)
@ -51,6 +51,7 @@ func TestSys_BillingOverview(t *testing.T) {
"data_protection_calls",
"pki_units",
"managed_keys",
"ssh_units",
}
for _, metricName := range expectedMetrics {
@ -86,6 +87,10 @@ func TestSys_BillingOverview(t *testing.T) {
require.Equal(t, "external_plugins", externalPluginsMetric.MetricName)
require.NotNil(t, externalPluginsMetric.MetricData)
require.Contains(t, externalPluginsMetric.MetricData, "total")
sshMetric := metricsMap["ssh_units"]
require.Contains(t, sshMetric.MetricData, "total")
require.Contains(t, sshMetric.MetricData, "metric_details")
}
func mockVaultBillingHandler(w http.ResponseWriter, _ *http.Request) {
@ -200,6 +205,22 @@ const billingOverviewResponse = `{
}
]
}
},
{
"metric_name": "ssh_units",
"metric_data": {
"total": 8.4,
"metric_details": [
{
"type": "otp_units",
"count": 5
},
{
"type": "certificate_units",
"count": 3.4
}
]
}
}
]
},

View file

@ -18,10 +18,11 @@ const operationPrefixSSH = "ssh"
type backend struct {
*framework.Backend
view logical.Storage
salt *salt.Salt
saltMutex sync.RWMutex
backendUUID string
view logical.Storage
salt *salt.Salt
saltMutex sync.RWMutex
backendUUID string
sshCertificateCounter logical.CertificateCounter
}
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
@ -84,6 +85,12 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
BackendType: logical.TypeLogical,
}
if sshCertCounterSysView, ok := conf.System.(logical.CertificateCountSystemView); ok {
b.sshCertificateCounter = sshCertCounterSysView.GetCertificateCounter()
} else {
b.sshCertificateCounter = logical.NewNullCertificateCounter()
}
b.backendUUID = conf.BackendUUID
return &b, nil
}

View file

@ -203,6 +203,8 @@ func (b *backend) GenerateOTPCredential(ctx context.Context, req *logical.Reques
if err := req.Storage.Put(ctx, newEntry); err != nil {
return "", err
}
b.sshCertificateCounter.Increment().AddSSHOTP()
return otp, nil
}

View file

@ -129,6 +129,8 @@ func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logic
return nil, nil, errors.New("error marshaling signed certificate")
}
b.sshCertificateCounter.Increment().AddSSHCertificate(ttl)
response := &logical.Response{
Data: map[string]interface{}{
"serial_number": strconv.FormatUint(certificate.Serial, 16),

View file

@ -6,6 +6,7 @@ package logical
import (
"crypto/x509"
"math"
"time"
)
// CertificateCounter is an interface for incrementing the count of issued and stored
@ -28,16 +29,20 @@ type CertCount struct {
// purposes. Each certificate's billable units = (Validity Hours ÷ 730), rounded to 4 decimal
// places.
PkiDurationAdjustedCerts float64
SSHIssuedCerts float64
SSHIssuedOTPs uint64
}
func (i *CertCount) Add(other CertCount) {
i.IssuedCerts += other.IssuedCerts
i.StoredCerts += other.StoredCerts
i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts
i.SSHIssuedCerts += other.SSHIssuedCerts
i.SSHIssuedOTPs += other.SSHIssuedOTPs
}
func (i *CertCount) IsZero() bool {
return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0
return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 && i.SSHIssuedCerts == 0 && i.SSHIssuedOTPs == 0
}
// durationAdjustedCertificateCount calculates the billable units for a certificate based on its
@ -58,6 +63,8 @@ func durationAdjustedCertificateCount(validitySeconds int64) float64 {
type CertCountIncrementer interface {
AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer
AddSSHCertificate(ttl time.Duration) CertCountIncrementer
AddSSHOTP() CertCountIncrementer
}
type certCountIncrementer struct {
@ -88,3 +95,21 @@ func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Cert
return c
}
func (c *certCountIncrementer) AddSSHCertificate(ttl time.Duration) CertCountIncrementer {
count := CertCount{
SSHIssuedCerts: durationAdjustedCertificateCount(int64(ttl.Seconds())),
}
c.counter.AddCount(count)
return c
}
func (c *certCountIncrementer) AddSSHOTP() CertCountIncrementer {
c.counter.AddCount(CertCount{
SSHIssuedOTPs: 1,
})
return c
}

View file

@ -29,6 +29,8 @@ const (
KmipEnabledPrefix = "kmipEnabled/"
PkiDurationAdjustedCountPrefix = "normalizedCertsIssued/"
MetricsLastUpdatedAtPrefix = "metricsLastUpdatedAt/"
SSHCertificateMetric = "ssh/normalized-certs-issued"
SSHOTPMetric = "ssh/credential-count"
BillingWriteInterval = 10 * time.Minute
// pluginCountsSendTimeout is the timeout for sending plugin counts to the active node

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
)
@ -758,3 +759,169 @@ func (c *Core) UpdateMetricsLastUpdateTime(ctx context.Context, currentMonth, up
return c.storeMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, normalizedMonth, updateTime)
}
// GetStoredSSHDurationAdjustedCertCount retrieves the stored SSH duration-adjusted certificate count
// for the specified month. The count is stored as a float64.
// Returns 0 if no count has been stored for the given month.
func (c *Core) GetStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time) (float64, error) {
c.consumptionBillingLock.RLock()
cb := c.consumptionBilling
c.consumptionBillingLock.RUnlock()
if cb == nil {
return 0, errors.New("consumption billing is not initialized")
}
cb.BillingStorageLock.RLock()
defer cb.BillingStorageLock.RUnlock()
return c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth)
}
func (c *Core) getStoredSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (float64, error) {
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric)
view, ok := c.GetBillingSubView()
if !ok {
return 0, errors.New("error reading SSH duration-adjusted count: billing subview not available")
}
se, err := view.Get(ctx, billingPath)
if se == nil || err != nil {
return 0, err
}
var certCount float64
err = se.DecodeJSON(&certCount)
if err != nil {
return 0, fmt.Errorf("error decoding current SSH duration adjusted cert count: %w", err)
}
return certCount, nil
}
func (c *Core) UpdateStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time, certCount float64) (float64, error) {
c.consumptionBillingLock.RLock()
cb := c.consumptionBilling
c.consumptionBillingLock.RUnlock()
if cb == nil {
return 0, ErrConsumptionBillingNotInitialized
}
cb.BillingStorageLock.Lock()
defer cb.BillingStorageLock.Unlock()
storedCertCount, err := c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth)
if err != nil {
return 0, err
}
err = c.storeSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth, certCount+storedCertCount)
if err != nil {
return 0, err
}
return certCount, nil
}
func (c *Core) storeSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, certCount float64) error {
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric)
countBytes, err := jsonutil.EncodeJSON(certCount)
if err != nil {
return err
}
entry := &logical.StorageEntry{
Key: billingPath,
Value: countBytes,
}
view, ok := c.GetBillingSubView()
if !ok {
return nil
}
return view.Put(ctx, entry)
}
// GetStoredSSHOTPCount retrieves the stored SSH OTP count
// for the specified month. The count is stored as a uint64.
// Returns 0 if no count has been stored for the given month.
func (c *Core) GetStoredSSHOTPCount(ctx context.Context, currentMonth time.Time) (uint64, error) {
c.consumptionBillingLock.RLock()
cb := c.consumptionBilling
c.consumptionBillingLock.RUnlock()
if cb == nil {
return 0, errors.New("consumption billing is not initialized")
}
cb.BillingStorageLock.RLock()
defer cb.BillingStorageLock.RUnlock()
return c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth)
}
func (c *Core) getStoredSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (uint64, error) {
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric)
view, ok := c.GetBillingSubView()
if !ok {
return 0, errors.New("error reading SSH OTP count: billing subview not available")
}
se, err := view.Get(ctx, billingPath)
if se == nil || err != nil {
return 0, err
}
var otpCount uint64
err = se.DecodeJSON(&otpCount)
if err != nil {
return 0, fmt.Errorf("error decoding current OTP cert count: %w", err)
}
return otpCount, nil
}
func (c *Core) UpdateStoredSSHOTPCount(ctx context.Context, currentMonth time.Time, otpCount uint64) (uint64, error) {
c.consumptionBillingLock.RLock()
cb := c.consumptionBilling
c.consumptionBillingLock.RUnlock()
if cb == nil {
return 0, ErrConsumptionBillingNotInitialized
}
cb.BillingStorageLock.Lock()
defer cb.BillingStorageLock.Unlock()
storedOTPCount, err := c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth)
if err != nil {
return 0, err
}
err = c.storeSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth, otpCount+storedOTPCount)
if err != nil {
return 0, err
}
return otpCount, nil
}
func (c *Core) storeSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, otpCount uint64) error {
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric)
countBytes, err := jsonutil.EncodeJSON(otpCount)
if err != nil {
return err
}
entry := &logical.StorageEntry{
Key: billingPath,
Value: countBytes,
}
view, ok := c.GetBillingSubView()
if !ok {
return nil
}
return view.Put(ctx, entry)
}

View file

@ -6,6 +6,7 @@ package vault
import (
"context"
"fmt"
"math"
"testing"
"time"
@ -23,6 +24,7 @@ import (
logicalDatabase "github.com/hashicorp/vault/builtin/logical/database"
logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad"
logicalRabbitMQ "github.com/hashicorp/vault/builtin/logical/rabbitmq"
"github.com/hashicorp/vault/builtin/logical/ssh"
"github.com/hashicorp/vault/builtin/logical/totp"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/helper/namespace"
@ -809,6 +811,265 @@ func TestTransitDataProtectionCallCounts(t *testing.T) {
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0")
}
// TestSSHCertCounts tests that we correctly store and track the SSH certificate counts
func TestSSHCertCounts(t *testing.T) {
standardDuration := 730.0
validityHours := float64(60*60*24) / 3600.0
units := validityHours / standardDuration
// Round to 4 decimal places
expectedCertUnit := math.Round(units*10000) / 10000
t.Parallel()
coreConfig := &CoreConfig{
LogicalBackends: map[string]logical.Factory{
"ssh": ssh.Factory,
},
}
core, _, root := TestCoreUnsealedWithConfig(t, coreConfig)
// Mount SSH backend
req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
req.Data["type"] = "ssh"
req.ClientToken = root
ctx := namespace.RootContext(context.Background())
_, err := core.HandleRequest(ctx, req)
require.NoError(t, err)
// Create a certificate
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
req.ClientToken = root
_, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test")
req.ClientToken = root
req.Data["key_type"] = "ca"
req.Data["allow_user_certificates"] = true
req.Data["allow_empty_principals"] = true
req.Data["ttl"] = "1d"
_, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test")
req.ClientToken = root
resp, err := core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Verify that the SSH counter is incremented
require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts)
// Test sign endpoint
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test")
req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key"
req.ClientToken = root
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Verify that the SSH counter is incremented
require.Equal(t, expectedCertUnit*2, core.certCountManager.GetCounts().SSHIssuedCerts)
// Now test persisting the summed counts - store and retrieve counts
// First, update the SSH cert counts (this will sum current counter with stored value)
currentCount := core.certCountManager.GetCounts().SSHIssuedCerts
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
// Retrieve the stored counts
storedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, currentCount, storedCounts)
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
// Perform more operations to increase the counter
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test")
req.ClientToken = root
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Counter should now be 1 cert
require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts)
// Update counts again - should sum the new count with the stored count
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
// Verify stored counts are now the sum
summedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
require.NoError(t, err)
expectedSum := currentCount + expectedCertUnit
require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current")
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
// Add more operations without manually resetting
for i := 0; i < 3; i++ {
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test")
req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key"
req.ClientToken = root
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
}
// Counter should be 3 certs
require.Equal(t, expectedCertUnit*3, core.certCountManager.GetCounts().SSHIssuedCerts)
// Update counts - should sum 3 with the previous stored sum
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
// Verify stored counts
storedCounts, err = core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
require.NoError(t, err)
expectedSum += expectedCertUnit * 3
require.Equal(t, expectedSum, storedCounts)
}
// TestSSHOTPCounts tests that we correctly store and track the SSH OTP counts
func TestSSHOTPCounts(t *testing.T) {
t.Parallel()
coreConfig := &CoreConfig{
LogicalBackends: map[string]logical.Factory{
"ssh": ssh.Factory,
},
}
core, _, root := TestCoreUnsealedWithConfig(t, coreConfig)
// Mount SSH backend
req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
req.Data["type"] = "ssh"
req.ClientToken = root
ctx := namespace.RootContext(context.Background())
_, err := core.HandleRequest(ctx, req)
require.NoError(t, err)
// Create a certificate
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
req.ClientToken = root
_, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test")
req.ClientToken = root
req.Data["key_type"] = "otp"
req.Data["default_user"] = "user"
_, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress")
req.ClientToken = root
req.Data["roles"] = "test"
_, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
req.ClientToken = root
req.Data["ip"] = "1.2.3.4"
resp, err := core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Verify that the SSH counter is incremented
require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs)
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
req.ClientToken = root
req.Data["ip"] = "1.2.3.4"
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Verify that the SSH counter is incremented
require.Equal(t, uint64(2), core.certCountManager.GetCounts().SSHIssuedOTPs)
// Now test persisting the summed counts - store and retrieve counts
// First, update the SSH cert counts (this will sum current counter with stored value)
currentCount := core.certCountManager.GetCounts().SSHIssuedOTPs
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
// Retrieve the stored counts
storedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, currentCount, storedCounts)
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
// Perform more operations to increase the counter
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
req.ClientToken = root
req.Data["ip"] = "1.2.3.4"
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Counter should now be 1
require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs)
// Update counts again - should sum the new count with the stored count
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
// Verify stored counts are now the sum
summedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now())
require.NoError(t, err)
expectedSum := currentCount + 1
require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current")
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
// Add more operations without manually resetting
for i := 0; i < 3; i++ {
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
req.ClientToken = root
req.Data["ip"] = "1.2.3.4"
resp, err = core.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
}
// Counter should be 3
require.Equal(t, uint64(3), core.certCountManager.GetCounts().SSHIssuedOTPs)
// Update counts - should sum 3 with the previous stored sum
core.certCountManager.StopConsumerJob()
time.Sleep(20 * time.Millisecond)
// Verify the counter was reset after update
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
// Verify stored counts
storedCounts, err = core.GetStoredSSHOTPCount(ctx, time.Now())
require.NoError(t, err)
expectedSum += 3
require.Equal(t, expectedSum, storedCounts)
}
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {
raw, ok := core.router.root.Get(mount + "/")
if !ok {

View file

@ -13,7 +13,10 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
var _ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil)
var (
_ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil)
_ logical.CertificateCountSystemView = (*extendedSystemViewImpl)(nil)
)
type extendedSystemViewImpl struct {
dynamicSystemView
@ -152,3 +155,7 @@ func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context,
func (e extendedSystemViewImpl) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
return e.core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName)
}
func (e extendedSystemViewImpl) GetCertificateCounter() logical.CertificateCounter {
return e.core.GetCertificateCounter()
}

View file

@ -172,6 +172,7 @@ func Test_BillingOverview_EmptyCluster(t *testing.T) {
"data_protection_calls": false,
"pki_units": false,
"managed_keys": false,
"ssh_units": false,
}
for _, metric := range currentMonth.UsageMetrics {

View file

@ -332,6 +332,21 @@ func (c *Core) consumeCertCounts(inc logical.CertCount) {
unconsumed.PkiDurationAdjustedCerts = 0
}
c.logger.Info("storing SSH counts", "sshDurationAdjustedCount", inc.SSHIssuedCerts, "sshOTPCount", inc.SSHIssuedOTPs)
_, err = c.UpdateStoredSSHDurationAdjustedCertCount(c.activeContext, time.Now(), inc.SSHIssuedCerts)
if err != nil {
c.logger.Error("error storing SSH duration adjusted certificate count", "error", err)
} else {
unconsumed.SSHIssuedCerts = 0
}
_, err = c.UpdateStoredSSHOTPCount(c.activeContext, time.Now(), inc.SSHIssuedOTPs)
if err != nil {
c.logger.Error("error storing SSH OTP count", "error", err)
} else {
unconsumed.SSHIssuedOTPs = 0
}
default:
c.logger.Error("Unexpected HA state when consuming certificate counts", "ha_state", haState)
}

View file

@ -210,6 +210,12 @@ func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Ti
},
})
sshCounts, err := b.buildSSHMetric(ctx, month)
if err != nil {
return nil, err
}
usageMetrics = append(usageMetrics, sshCounts)
dataUpdatedAt := b.Core.computeUpdatedAt(ctx, month, currentMonth)
monthStr := month.Format("2006-01")
@ -486,3 +492,32 @@ func (c *Core) getThirdPartyPluginCounts(ctx context.Context, month time.Time) (
return thirdPartyPluginCounts, nil
}
func (b *SystemBackend) buildSSHMetric(ctx context.Context, month time.Time) (map[string]interface{}, error) {
certCounts, err := b.Core.GetStoredSSHDurationAdjustedCertCount(ctx, month)
if err != nil {
return nil, fmt.Errorf("error retrieving SSH duration-adjuested cert counts for current month: %w", err)
}
otpCounts, err := b.Core.GetStoredSSHOTPCount(ctx, month)
if err != nil {
return nil, fmt.Errorf("error retrieving SSH OTP counts for current month: %w", err)
}
return map[string]interface{}{
"metric_name": "ssh_units",
"metric_data": map[string]interface{}{
"total": certCounts + float64(otpCounts),
"metric_details": []map[string]interface{}{
{
"type": "otp_units",
"count": otpCounts,
},
{
"type": "certificate_units",
"count": certCounts,
},
},
},
}, nil
}

View file

@ -11,6 +11,7 @@ import (
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
logicalAws "github.com/hashicorp/vault/builtin/logical/aws"
logicalDatabase "github.com/hashicorp/vault/builtin/logical/database"
logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh"
logicalTransit "github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/pluginconsts"
@ -172,6 +173,7 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
pluginconsts.SecretEngineAWS: logicalAws.Factory,
pluginconsts.SecretEngineDatabase: logicalDatabase.Factory,
pluginconsts.SecretEngineTransit: logicalTransit.Factory,
pluginconsts.SecretEngineSsh: logicalSsh.Factory,
},
})
b := c.systemBackend
@ -233,6 +235,53 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
_, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
// Create SSH certificate and OTP
req = logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
req.Data["type"] = "ssh"
req.ClientToken = root
resp, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
req.ClientToken = root
resp, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-cert")
req.ClientToken = root
req.Data["key_type"] = "ca"
req.Data["allow_user_certificates"] = true
req.Data["allow_empty_principals"] = true
req.Data["ttl"] = "1d"
_, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test-cert")
req.ClientToken = root
resp, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-otp")
req.ClientToken = root
req.Data["key_type"] = "otp"
req.Data["default_user"] = "user"
_, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress")
req.ClientToken = root
req.Data["roles"] = "test-otp"
_, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test-otp")
req.ClientToken = root
req.Data["ip"] = "1.2.3.4"
resp, err = c.HandleRequest(ctx, req)
require.NoError(t, err)
require.Nil(t, resp.Error())
// Update all metrics
currentMonth := time.Now()
_, err = c.UpdateMaxKvCounts(ctx, billing.LocalPrefix, currentMonth)
@ -244,6 +293,12 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
_, err = c.UpdateTransitCallCounts(ctx, currentMonth)
require.NoError(t, err)
_, err = c.UpdateStoredSSHDurationAdjustedCertCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedCerts)
require.NoError(t, err)
_, err = c.UpdateStoredSSHOTPCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedOTPs)
require.NoError(t, err)
// Make a request to the billing overview endpoint
req = logical.TestRequest(t, logical.ReadOperation, "billing/overview")
req.Data["refresh_data"] = true
@ -366,6 +421,24 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
require.True(t, ok, "managed_keys total should be int")
require.GreaterOrEqual(t, total, 0)
require.Contains(t, metricData, "metric_details")
case "ssh_units":
require.Contains(t, metricData, "total")
total, ok := metricData["total"].(float64)
require.True(t, ok, "ssh_units total should be float64")
require.GreaterOrEqual(t, total, float64(0))
require.Contains(t, metricData, "metric_details")
metricDetails, ok := metricData["metric_details"].([]map[string]interface{})
require.True(t, ok, "metric_details should be []map[string]interface{}")
require.NotEmpty(t, metricDetails)
require.Equal(t, len(metricDetails), 2)
require.Equal(t, metricDetails[0]["type"], "otp_units")
require.GreaterOrEqual(t, metricDetails[0]["count"], uint64(0))
require.Equal(t, metricDetails[1]["type"], "certificate_units")
require.GreaterOrEqual(t, metricDetails[1]["count"], float64(0))
}
}
@ -469,6 +542,7 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) {
"data_protection_calls": false,
"pki_units": false,
"managed_keys": false,
"ssh_units": false,
}
for _, metric := range usageMetrics {
@ -522,6 +596,10 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) {
details, ok := metricData["metric_details"].([]map[string]interface{})
require.True(t, ok, "%s metric_details should be array", metricName)
require.Empty(t, details, "%s metric_details should be empty when total is 0", metricName)
case "ssh_units":
total, ok := metricData["total"].(float64)
require.True(t, ok, "ssh_units total should be float64")
require.Equal(t, float64(0), total, "ssh_units total should be 0")
}
}