mirror of
https://github.com/hashicorp/vault.git
synced 2026-06-08 16:24:51 -04:00
* add cert counting for ssh * add system view and fix errors * add otp counting and change units for certs * add storage tests * fix census errors * run make fmt * use incrementer and change storage to match rfc * run make fmt * fix interface and remove parameter * fix errors * Update builtin/logical/ssh/path_creds_create.go * remove error check * add ssh counts to billing endpoint * fix error * add test case * add ssh metric to test * add get functions and tests * fix format * create function for ssh metrics * refactoring and add test cases * replace test check * add ssh to billing overview test --------- Co-authored-by: Rachel Culpepper <84159930+rculpepper@users.noreply.github.com> Co-authored-by: Victor Rodriguez Rizo <vrizo@hashicorp.com>
This commit is contained in:
parent
921dc42cdc
commit
d34cb72e68
13 changed files with 630 additions and 7 deletions
|
|
@ -33,7 +33,7 @@ func TestSys_BillingOverview(t *testing.T) {
|
|||
currentMonth := resp.Months[0]
|
||||
require.Equal(t, "2026-01", currentMonth.Month)
|
||||
require.Equal(t, "2026-01-14T10:49:00Z", currentMonth.UpdatedAt)
|
||||
require.Len(t, currentMonth.UsageMetrics, 8, "should have all 8 metrics")
|
||||
require.Len(t, currentMonth.UsageMetrics, 9, "should have all 9 metrics")
|
||||
|
||||
// Create a map to verify all expected metrics are present
|
||||
metricsMap := make(map[string]UsageMetric)
|
||||
|
|
@ -51,6 +51,7 @@ func TestSys_BillingOverview(t *testing.T) {
|
|||
"data_protection_calls",
|
||||
"pki_units",
|
||||
"managed_keys",
|
||||
"ssh_units",
|
||||
}
|
||||
|
||||
for _, metricName := range expectedMetrics {
|
||||
|
|
@ -86,6 +87,10 @@ func TestSys_BillingOverview(t *testing.T) {
|
|||
require.Equal(t, "external_plugins", externalPluginsMetric.MetricName)
|
||||
require.NotNil(t, externalPluginsMetric.MetricData)
|
||||
require.Contains(t, externalPluginsMetric.MetricData, "total")
|
||||
|
||||
sshMetric := metricsMap["ssh_units"]
|
||||
require.Contains(t, sshMetric.MetricData, "total")
|
||||
require.Contains(t, sshMetric.MetricData, "metric_details")
|
||||
}
|
||||
|
||||
func mockVaultBillingHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
|
|
@ -200,6 +205,22 @@ const billingOverviewResponse = `{
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"metric_name": "ssh_units",
|
||||
"metric_data": {
|
||||
"total": 8.4,
|
||||
"metric_details": [
|
||||
{
|
||||
"type": "otp_units",
|
||||
"count": 5
|
||||
},
|
||||
{
|
||||
"type": "certificate_units",
|
||||
"count": 3.4
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -18,10 +18,11 @@ const operationPrefixSSH = "ssh"
|
|||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
view logical.Storage
|
||||
salt *salt.Salt
|
||||
saltMutex sync.RWMutex
|
||||
backendUUID string
|
||||
view logical.Storage
|
||||
salt *salt.Salt
|
||||
saltMutex sync.RWMutex
|
||||
backendUUID string
|
||||
sshCertificateCounter logical.CertificateCounter
|
||||
}
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
|
|
@ -84,6 +85,12 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
if sshCertCounterSysView, ok := conf.System.(logical.CertificateCountSystemView); ok {
|
||||
b.sshCertificateCounter = sshCertCounterSysView.GetCertificateCounter()
|
||||
} else {
|
||||
b.sshCertificateCounter = logical.NewNullCertificateCounter()
|
||||
}
|
||||
|
||||
b.backendUUID = conf.BackendUUID
|
||||
return &b, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -203,6 +203,8 @@ func (b *backend) GenerateOTPCredential(ctx context.Context, req *logical.Reques
|
|||
if err := req.Storage.Put(ctx, newEntry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
b.sshCertificateCounter.Increment().AddSSHOTP()
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,8 @@ func (b *backend) pathSignIssueCertificateHelper(ctx context.Context, req *logic
|
|||
return nil, nil, errors.New("error marshaling signed certificate")
|
||||
}
|
||||
|
||||
b.sshCertificateCounter.Increment().AddSSHCertificate(ttl)
|
||||
|
||||
response := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"serial_number": strconv.FormatUint(certificate.Serial, 16),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package logical
|
|||
import (
|
||||
"crypto/x509"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertificateCounter is an interface for incrementing the count of issued and stored
|
||||
|
|
@ -28,16 +29,20 @@ type CertCount struct {
|
|||
// purposes. Each certificate's billable units = (Validity Hours ÷ 730), rounded to 4 decimal
|
||||
// places.
|
||||
PkiDurationAdjustedCerts float64
|
||||
SSHIssuedCerts float64
|
||||
SSHIssuedOTPs uint64
|
||||
}
|
||||
|
||||
func (i *CertCount) Add(other CertCount) {
|
||||
i.IssuedCerts += other.IssuedCerts
|
||||
i.StoredCerts += other.StoredCerts
|
||||
i.PkiDurationAdjustedCerts += other.PkiDurationAdjustedCerts
|
||||
i.SSHIssuedCerts += other.SSHIssuedCerts
|
||||
i.SSHIssuedOTPs += other.SSHIssuedOTPs
|
||||
}
|
||||
|
||||
func (i *CertCount) IsZero() bool {
|
||||
return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0
|
||||
return i.IssuedCerts == 0 && i.StoredCerts == 0 && i.PkiDurationAdjustedCerts == 0 && i.SSHIssuedCerts == 0 && i.SSHIssuedOTPs == 0
|
||||
}
|
||||
|
||||
// durationAdjustedCertificateCount calculates the billable units for a certificate based on its
|
||||
|
|
@ -58,6 +63,8 @@ func durationAdjustedCertificateCount(validitySeconds int64) float64 {
|
|||
|
||||
type CertCountIncrementer interface {
|
||||
AddIssuedCertificate(stored bool, cert *x509.Certificate) CertCountIncrementer
|
||||
AddSSHCertificate(ttl time.Duration) CertCountIncrementer
|
||||
AddSSHOTP() CertCountIncrementer
|
||||
}
|
||||
|
||||
type certCountIncrementer struct {
|
||||
|
|
@ -88,3 +95,21 @@ func (c *certCountIncrementer) AddIssuedCertificate(stored bool, cert *x509.Cert
|
|||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *certCountIncrementer) AddSSHCertificate(ttl time.Duration) CertCountIncrementer {
|
||||
count := CertCount{
|
||||
SSHIssuedCerts: durationAdjustedCertificateCount(int64(ttl.Seconds())),
|
||||
}
|
||||
|
||||
c.counter.AddCount(count)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *certCountIncrementer) AddSSHOTP() CertCountIncrementer {
|
||||
c.counter.AddCount(CertCount{
|
||||
SSHIssuedOTPs: 1,
|
||||
})
|
||||
|
||||
return c
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ const (
|
|||
KmipEnabledPrefix = "kmipEnabled/"
|
||||
PkiDurationAdjustedCountPrefix = "normalizedCertsIssued/"
|
||||
MetricsLastUpdatedAtPrefix = "metricsLastUpdatedAt/"
|
||||
SSHCertificateMetric = "ssh/normalized-certs-issued"
|
||||
SSHOTPMetric = "ssh/credential-count"
|
||||
|
||||
BillingWriteInterval = 10 * time.Minute
|
||||
// pluginCountsSendTimeout is the timeout for sending plugin counts to the active node
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/timeutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
)
|
||||
|
|
@ -758,3 +759,169 @@ func (c *Core) UpdateMetricsLastUpdateTime(ctx context.Context, currentMonth, up
|
|||
|
||||
return c.storeMetricsLastUpdateTimeLocked(ctx, billing.LocalPrefix, normalizedMonth, updateTime)
|
||||
}
|
||||
|
||||
// GetStoredSSHDurationAdjustedCertCount retrieves the stored SSH duration-adjusted certificate count
|
||||
// for the specified month. The count is stored as a float64.
|
||||
// Returns 0 if no count has been stored for the given month.
|
||||
func (c *Core) GetStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time) (float64, error) {
|
||||
c.consumptionBillingLock.RLock()
|
||||
cb := c.consumptionBilling
|
||||
c.consumptionBillingLock.RUnlock()
|
||||
|
||||
if cb == nil {
|
||||
return 0, errors.New("consumption billing is not initialized")
|
||||
}
|
||||
|
||||
cb.BillingStorageLock.RLock()
|
||||
defer cb.BillingStorageLock.RUnlock()
|
||||
|
||||
return c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||
}
|
||||
|
||||
func (c *Core) getStoredSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (float64, error) {
|
||||
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric)
|
||||
|
||||
view, ok := c.GetBillingSubView()
|
||||
if !ok {
|
||||
return 0, errors.New("error reading SSH duration-adjusted count: billing subview not available")
|
||||
}
|
||||
|
||||
se, err := view.Get(ctx, billingPath)
|
||||
if se == nil || err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var certCount float64
|
||||
err = se.DecodeJSON(&certCount)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error decoding current SSH duration adjusted cert count: %w", err)
|
||||
}
|
||||
|
||||
return certCount, nil
|
||||
}
|
||||
|
||||
func (c *Core) UpdateStoredSSHDurationAdjustedCertCount(ctx context.Context, currentMonth time.Time, certCount float64) (float64, error) {
|
||||
c.consumptionBillingLock.RLock()
|
||||
cb := c.consumptionBilling
|
||||
c.consumptionBillingLock.RUnlock()
|
||||
|
||||
if cb == nil {
|
||||
return 0, ErrConsumptionBillingNotInitialized
|
||||
}
|
||||
cb.BillingStorageLock.Lock()
|
||||
defer cb.BillingStorageLock.Unlock()
|
||||
storedCertCount, err := c.getStoredSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = c.storeSSHDurationAdjustedCertCountLocked(ctx, billing.LocalPrefix, currentMonth, certCount+storedCertCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return certCount, nil
|
||||
}
|
||||
|
||||
func (c *Core) storeSSHDurationAdjustedCertCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, certCount float64) error {
|
||||
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHCertificateMetric)
|
||||
|
||||
countBytes, err := jsonutil.EncodeJSON(certCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entry := &logical.StorageEntry{
|
||||
Key: billingPath,
|
||||
Value: countBytes,
|
||||
}
|
||||
|
||||
view, ok := c.GetBillingSubView()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return view.Put(ctx, entry)
|
||||
}
|
||||
|
||||
// GetStoredSSHOTPCount retrieves the stored SSH OTP count
|
||||
// for the specified month. The count is stored as a uint64.
|
||||
// Returns 0 if no count has been stored for the given month.
|
||||
func (c *Core) GetStoredSSHOTPCount(ctx context.Context, currentMonth time.Time) (uint64, error) {
|
||||
c.consumptionBillingLock.RLock()
|
||||
cb := c.consumptionBilling
|
||||
c.consumptionBillingLock.RUnlock()
|
||||
|
||||
if cb == nil {
|
||||
return 0, errors.New("consumption billing is not initialized")
|
||||
}
|
||||
|
||||
cb.BillingStorageLock.RLock()
|
||||
defer cb.BillingStorageLock.RUnlock()
|
||||
|
||||
return c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||
}
|
||||
|
||||
func (c *Core) getStoredSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time) (uint64, error) {
|
||||
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric)
|
||||
|
||||
view, ok := c.GetBillingSubView()
|
||||
if !ok {
|
||||
return 0, errors.New("error reading SSH OTP count: billing subview not available")
|
||||
}
|
||||
|
||||
se, err := view.Get(ctx, billingPath)
|
||||
if se == nil || err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var otpCount uint64
|
||||
err = se.DecodeJSON(&otpCount)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error decoding current OTP cert count: %w", err)
|
||||
}
|
||||
|
||||
return otpCount, nil
|
||||
}
|
||||
|
||||
func (c *Core) UpdateStoredSSHOTPCount(ctx context.Context, currentMonth time.Time, otpCount uint64) (uint64, error) {
|
||||
c.consumptionBillingLock.RLock()
|
||||
cb := c.consumptionBilling
|
||||
c.consumptionBillingLock.RUnlock()
|
||||
|
||||
if cb == nil {
|
||||
return 0, ErrConsumptionBillingNotInitialized
|
||||
}
|
||||
cb.BillingStorageLock.Lock()
|
||||
defer cb.BillingStorageLock.Unlock()
|
||||
storedOTPCount, err := c.getStoredSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = c.storeSSHOTPCountLocked(ctx, billing.LocalPrefix, currentMonth, otpCount+storedOTPCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return otpCount, nil
|
||||
}
|
||||
|
||||
func (c *Core) storeSSHOTPCountLocked(ctx context.Context, localPathPrefix string, currentMonth time.Time, otpCount uint64) error {
|
||||
billingPath := billing.GetMonthlyBillingMetricPath(localPathPrefix, currentMonth, billing.SSHOTPMetric)
|
||||
|
||||
countBytes, err := jsonutil.EncodeJSON(otpCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entry := &logical.StorageEntry{
|
||||
Key: billingPath,
|
||||
Value: countBytes,
|
||||
}
|
||||
|
||||
view, ok := c.GetBillingSubView()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return view.Put(ctx, entry)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package vault
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -23,6 +24,7 @@ import (
|
|||
logicalDatabase "github.com/hashicorp/vault/builtin/logical/database"
|
||||
logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad"
|
||||
logicalRabbitMQ "github.com/hashicorp/vault/builtin/logical/rabbitmq"
|
||||
"github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
"github.com/hashicorp/vault/builtin/logical/totp"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
|
@ -809,6 +811,265 @@ func TestTransitDataProtectionCallCounts(t *testing.T) {
|
|||
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0")
|
||||
}
|
||||
|
||||
// TestSSHCertCounts tests that we correctly store and track the SSH certificate counts
|
||||
func TestSSHCertCounts(t *testing.T) {
|
||||
standardDuration := 730.0
|
||||
validityHours := float64(60*60*24) / 3600.0
|
||||
units := validityHours / standardDuration
|
||||
// Round to 4 decimal places
|
||||
expectedCertUnit := math.Round(units*10000) / 10000
|
||||
|
||||
t.Parallel()
|
||||
coreConfig := &CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"ssh": ssh.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
core, _, root := TestCoreUnsealedWithConfig(t, coreConfig)
|
||||
|
||||
// Mount SSH backend
|
||||
req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
|
||||
req.Data["type"] = "ssh"
|
||||
req.ClientToken = root
|
||||
ctx := namespace.RootContext(context.Background())
|
||||
_, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a certificate
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data["key_type"] = "ca"
|
||||
req.Data["allow_user_certificates"] = true
|
||||
req.Data["allow_empty_principals"] = true
|
||||
req.Data["ttl"] = "1d"
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test")
|
||||
req.ClientToken = root
|
||||
resp, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Verify that the SSH counter is incremented
|
||||
require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts)
|
||||
|
||||
// Test sign endpoint
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test")
|
||||
req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key"
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Verify that the SSH counter is incremented
|
||||
require.Equal(t, expectedCertUnit*2, core.certCountManager.GetCounts().SSHIssuedCerts)
|
||||
|
||||
// Now test persisting the summed counts - store and retrieve counts
|
||||
// First, update the SSH cert counts (this will sum current counter with stored value)
|
||||
currentCount := core.certCountManager.GetCounts().SSHIssuedCerts
|
||||
core.certCountManager.StopConsumerJob()
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
|
||||
|
||||
// Retrieve the stored counts
|
||||
storedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentCount, storedCounts)
|
||||
|
||||
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
|
||||
|
||||
// Perform more operations to increase the counter
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test")
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Counter should now be 1 cert
|
||||
require.Equal(t, expectedCertUnit, core.certCountManager.GetCounts().SSHIssuedCerts)
|
||||
|
||||
// Update counts again - should sum the new count with the stored count
|
||||
core.certCountManager.StopConsumerJob()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts are now the sum
|
||||
summedCounts, err := core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedSum := currentCount + expectedCertUnit
|
||||
require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current")
|
||||
|
||||
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
|
||||
|
||||
// Add more operations without manually resetting
|
||||
for i := 0; i < 3; i++ {
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/sign/test")
|
||||
req.Data["public_key"] = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJBp4mozY/snvG/+pkgv4xYifIFB2ov3gAvAqXgFqNpj vault-enterprise-key"
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
}
|
||||
|
||||
// Counter should be 3 certs
|
||||
require.Equal(t, expectedCertUnit*3, core.certCountManager.GetCounts().SSHIssuedCerts)
|
||||
|
||||
// Update counts - should sum 3 with the previous stored sum
|
||||
core.certCountManager.StopConsumerJob()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, float64(0), core.certCountManager.GetCounts().SSHIssuedCerts, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts
|
||||
storedCounts, err = core.GetStoredSSHDurationAdjustedCertCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
expectedSum += expectedCertUnit * 3
|
||||
require.Equal(t, expectedSum, storedCounts)
|
||||
}
|
||||
|
||||
// TestSSHOTPCounts tests that we correctly store and track the SSH OTP counts
|
||||
func TestSSHOTPCounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
coreConfig := &CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"ssh": ssh.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
core, _, root := TestCoreUnsealedWithConfig(t, coreConfig)
|
||||
|
||||
// Mount SSH backend
|
||||
req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
|
||||
req.Data["type"] = "ssh"
|
||||
req.ClientToken = root
|
||||
ctx := namespace.RootContext(context.Background())
|
||||
_, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a certificate
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data["key_type"] = "otp"
|
||||
req.Data["default_user"] = "user"
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress")
|
||||
req.ClientToken = root
|
||||
req.Data["roles"] = "test"
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
|
||||
req.ClientToken = root
|
||||
req.Data["ip"] = "1.2.3.4"
|
||||
resp, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Verify that the SSH counter is incremented
|
||||
require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs)
|
||||
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
|
||||
req.ClientToken = root
|
||||
req.Data["ip"] = "1.2.3.4"
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Verify that the SSH counter is incremented
|
||||
require.Equal(t, uint64(2), core.certCountManager.GetCounts().SSHIssuedOTPs)
|
||||
|
||||
// Now test persisting the summed counts - store and retrieve counts
|
||||
// First, update the SSH cert counts (this will sum current counter with stored value)
|
||||
currentCount := core.certCountManager.GetCounts().SSHIssuedOTPs
|
||||
core.certCountManager.StopConsumerJob()
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
|
||||
|
||||
// Retrieve the stored counts
|
||||
storedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentCount, storedCounts)
|
||||
|
||||
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
|
||||
|
||||
// Perform more operations to increase the counter
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
|
||||
req.ClientToken = root
|
||||
req.Data["ip"] = "1.2.3.4"
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Counter should now be 1
|
||||
require.Equal(t, uint64(1), core.certCountManager.GetCounts().SSHIssuedOTPs)
|
||||
|
||||
// Update counts again - should sum the new count with the stored count
|
||||
core.certCountManager.StopConsumerJob()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts are now the sum
|
||||
summedCounts, err := core.GetStoredSSHOTPCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedSum := currentCount + 1
|
||||
require.Equal(t, expectedSum, summedCounts, "Count should be sum of stored and current")
|
||||
|
||||
core.certCountManager.StartConsumerJob(core.consumeCertCounts)
|
||||
|
||||
// Add more operations without manually resetting
|
||||
for i := 0; i < 3; i++ {
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test")
|
||||
req.ClientToken = root
|
||||
req.Data["ip"] = "1.2.3.4"
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
}
|
||||
|
||||
// Counter should be 3
|
||||
require.Equal(t, uint64(3), core.certCountManager.GetCounts().SSHIssuedOTPs)
|
||||
|
||||
// Update counts - should sum 3 with the previous stored sum
|
||||
core.certCountManager.StopConsumerJob()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, uint64(0), core.certCountManager.GetCounts().SSHIssuedOTPs, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts
|
||||
storedCounts, err = core.GetStoredSSHOTPCount(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
expectedSum += 3
|
||||
require.Equal(t, expectedSum, storedCounts)
|
||||
}
|
||||
|
||||
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {
|
||||
raw, ok := core.router.root.Get(mount + "/")
|
||||
if !ok {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
var _ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil)
|
||||
var (
|
||||
_ logical.ExtendedSystemView = (*extendedSystemViewImpl)(nil)
|
||||
_ logical.CertificateCountSystemView = (*extendedSystemViewImpl)(nil)
|
||||
)
|
||||
|
||||
type extendedSystemViewImpl struct {
|
||||
dynamicSystemView
|
||||
|
|
@ -152,3 +155,7 @@ func (e extendedSystemViewImpl) DeregisterWellKnownRedirect(ctx context.Context,
|
|||
func (e extendedSystemViewImpl) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
|
||||
return e.core.pluginCatalog.GetPinnedVersion(ctx, pluginType, pluginName)
|
||||
}
|
||||
|
||||
func (e extendedSystemViewImpl) GetCertificateCounter() logical.CertificateCounter {
|
||||
return e.core.GetCertificateCounter()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ func Test_BillingOverview_EmptyCluster(t *testing.T) {
|
|||
"data_protection_calls": false,
|
||||
"pki_units": false,
|
||||
"managed_keys": false,
|
||||
"ssh_units": false,
|
||||
}
|
||||
|
||||
for _, metric := range currentMonth.UsageMetrics {
|
||||
|
|
|
|||
|
|
@ -332,6 +332,21 @@ func (c *Core) consumeCertCounts(inc logical.CertCount) {
|
|||
unconsumed.PkiDurationAdjustedCerts = 0
|
||||
}
|
||||
|
||||
c.logger.Info("storing SSH counts", "sshDurationAdjustedCount", inc.SSHIssuedCerts, "sshOTPCount", inc.SSHIssuedOTPs)
|
||||
_, err = c.UpdateStoredSSHDurationAdjustedCertCount(c.activeContext, time.Now(), inc.SSHIssuedCerts)
|
||||
if err != nil {
|
||||
c.logger.Error("error storing SSH duration adjusted certificate count", "error", err)
|
||||
} else {
|
||||
unconsumed.SSHIssuedCerts = 0
|
||||
}
|
||||
|
||||
_, err = c.UpdateStoredSSHOTPCount(c.activeContext, time.Now(), inc.SSHIssuedOTPs)
|
||||
if err != nil {
|
||||
c.logger.Error("error storing SSH OTP count", "error", err)
|
||||
} else {
|
||||
unconsumed.SSHIssuedOTPs = 0
|
||||
}
|
||||
|
||||
default:
|
||||
c.logger.Error("Unexpected HA state when consuming certificate counts", "ha_state", haState)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -210,6 +210,12 @@ func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Ti
|
|||
},
|
||||
})
|
||||
|
||||
sshCounts, err := b.buildSSHMetric(ctx, month)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usageMetrics = append(usageMetrics, sshCounts)
|
||||
|
||||
dataUpdatedAt := b.Core.computeUpdatedAt(ctx, month, currentMonth)
|
||||
|
||||
monthStr := month.Format("2006-01")
|
||||
|
|
@ -486,3 +492,32 @@ func (c *Core) getThirdPartyPluginCounts(ctx context.Context, month time.Time) (
|
|||
|
||||
return thirdPartyPluginCounts, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) buildSSHMetric(ctx context.Context, month time.Time) (map[string]interface{}, error) {
|
||||
certCounts, err := b.Core.GetStoredSSHDurationAdjustedCertCount(ctx, month)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving SSH duration-adjuested cert counts for current month: %w", err)
|
||||
}
|
||||
|
||||
otpCounts, err := b.Core.GetStoredSSHOTPCount(ctx, month)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving SSH OTP counts for current month: %w", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"metric_name": "ssh_units",
|
||||
"metric_data": map[string]interface{}{
|
||||
"total": certCounts + float64(otpCounts),
|
||||
"metric_details": []map[string]interface{}{
|
||||
{
|
||||
"type": "otp_units",
|
||||
"count": otpCounts,
|
||||
},
|
||||
{
|
||||
"type": "certificate_units",
|
||||
"count": certCounts,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
|
||||
logicalAws "github.com/hashicorp/vault/builtin/logical/aws"
|
||||
logicalDatabase "github.com/hashicorp/vault/builtin/logical/database"
|
||||
logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
logicalTransit "github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/pluginconsts"
|
||||
|
|
@ -172,6 +173,7 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
|
|||
pluginconsts.SecretEngineAWS: logicalAws.Factory,
|
||||
pluginconsts.SecretEngineDatabase: logicalDatabase.Factory,
|
||||
pluginconsts.SecretEngineTransit: logicalTransit.Factory,
|
||||
pluginconsts.SecretEngineSsh: logicalSsh.Factory,
|
||||
},
|
||||
})
|
||||
b := c.systemBackend
|
||||
|
|
@ -233,6 +235,53 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
|
|||
_, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create SSH certificate and OTP
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "sys/mounts/ssh")
|
||||
req.Data["type"] = "ssh"
|
||||
req.ClientToken = root
|
||||
resp, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/ca")
|
||||
req.ClientToken = root
|
||||
resp, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-cert")
|
||||
req.ClientToken = root
|
||||
req.Data["key_type"] = "ca"
|
||||
req.Data["allow_user_certificates"] = true
|
||||
req.Data["allow_empty_principals"] = true
|
||||
req.Data["ttl"] = "1d"
|
||||
_, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/issue/test-cert")
|
||||
req.ClientToken = root
|
||||
resp, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/roles/test-otp")
|
||||
req.ClientToken = root
|
||||
req.Data["key_type"] = "otp"
|
||||
req.Data["default_user"] = "user"
|
||||
_, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "ssh/config/zeroaddress")
|
||||
req.ClientToken = root
|
||||
req.Data["roles"] = "test-otp"
|
||||
_, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "ssh/creds/test-otp")
|
||||
req.ClientToken = root
|
||||
req.Data["ip"] = "1.2.3.4"
|
||||
resp, err = c.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.Error())
|
||||
|
||||
// Update all metrics
|
||||
currentMonth := time.Now()
|
||||
_, err = c.UpdateMaxKvCounts(ctx, billing.LocalPrefix, currentMonth)
|
||||
|
|
@ -244,6 +293,12 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
|
|||
_, err = c.UpdateTransitCallCounts(ctx, currentMonth)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = c.UpdateStoredSSHDurationAdjustedCertCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedCerts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = c.UpdateStoredSSHOTPCount(ctx, currentMonth, c.certCountManager.GetCounts().SSHIssuedOTPs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make a request to the billing overview endpoint
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "billing/overview")
|
||||
req.Data["refresh_data"] = true
|
||||
|
|
@ -366,6 +421,24 @@ func TestSystemBackend_BillingOverview_MetricFormats(t *testing.T) {
|
|||
require.True(t, ok, "managed_keys total should be int")
|
||||
require.GreaterOrEqual(t, total, 0)
|
||||
require.Contains(t, metricData, "metric_details")
|
||||
|
||||
case "ssh_units":
|
||||
require.Contains(t, metricData, "total")
|
||||
total, ok := metricData["total"].(float64)
|
||||
require.True(t, ok, "ssh_units total should be float64")
|
||||
require.GreaterOrEqual(t, total, float64(0))
|
||||
|
||||
require.Contains(t, metricData, "metric_details")
|
||||
metricDetails, ok := metricData["metric_details"].([]map[string]interface{})
|
||||
require.True(t, ok, "metric_details should be []map[string]interface{}")
|
||||
require.NotEmpty(t, metricDetails)
|
||||
require.Equal(t, len(metricDetails), 2)
|
||||
|
||||
require.Equal(t, metricDetails[0]["type"], "otp_units")
|
||||
require.GreaterOrEqual(t, metricDetails[0]["count"], uint64(0))
|
||||
|
||||
require.Equal(t, metricDetails[1]["type"], "certificate_units")
|
||||
require.GreaterOrEqual(t, metricDetails[1]["count"], float64(0))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -469,6 +542,7 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) {
|
|||
"data_protection_calls": false,
|
||||
"pki_units": false,
|
||||
"managed_keys": false,
|
||||
"ssh_units": false,
|
||||
}
|
||||
|
||||
for _, metric := range usageMetrics {
|
||||
|
|
@ -522,6 +596,10 @@ func TestSystemBackend_BillingOverview_EmptyMetrics(t *testing.T) {
|
|||
details, ok := metricData["metric_details"].([]map[string]interface{})
|
||||
require.True(t, ok, "%s metric_details should be array", metricName)
|
||||
require.Empty(t, details, "%s metric_details should be empty when total is 0", metricName)
|
||||
case "ssh_units":
|
||||
total, ok := metricData["total"].(float64)
|
||||
require.True(t, ok, "ssh_units total should be float64")
|
||||
require.Equal(t, float64(0), total, "ssh_units total should be 0")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue