diff --git a/buf.yaml b/buf.yaml index 4b31f93fb9..a554c6d827 100644 --- a/buf.yaml +++ b/buf.yaml @@ -51,6 +51,7 @@ lint: - sdk/plugin/pb/backend.proto - sdk/plugin/pb/system_view_service_ent.proto - vault/activity/activity_log.proto + - vault/billing/billing.proto - sdk/helper/clientcountutil/generation/generate_data.proto - vault/hcp_link/proto/link_control/link_control.proto - vault/hcp_link/proto/meta/meta.proto @@ -85,6 +86,7 @@ lint: - sdk/plugin/pb/backend.proto - sdk/plugin/pb/system_view_service_ent.proto - vault/activity/activity_log.proto + - vault/billing/billing.proto - sdk/helper/clientcountutil/generation/generate_data.proto - vault/hcp_link/proto/link_control/link_control.proto - vault/hcp_link/proto/meta/meta.proto diff --git a/vault/billing/billing.pb.go b/vault/billing/billing.pb.go new file mode 100644 index 0000000000..463fccd4b6 --- /dev/null +++ b/vault/billing/billing.pb.go @@ -0,0 +1,150 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.5 +// protoc (unknown) +// source: vault/billing/billing.proto + +package billing + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// PluginBillingDataRequest contains the in-memory data protection call counts +// from a performance standby node to be sent to the active node +type PluginBillingDataRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Map of plugin type to count (e.g., "transit" -> count) + PluginData map[string]*anypb.Any `protobuf:"bytes,1,rep,name=plugin_data,json=pluginData,proto3" json:"plugin_data,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PluginBillingDataRequest) Reset() { + *x = PluginBillingDataRequest{} + mi := &file_vault_billing_billing_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PluginBillingDataRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PluginBillingDataRequest) ProtoMessage() {} + +func (x *PluginBillingDataRequest) ProtoReflect() protoreflect.Message { + mi := &file_vault_billing_billing_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PluginBillingDataRequest.ProtoReflect.Descriptor instead. +func (*PluginBillingDataRequest) Descriptor() ([]byte, []int) { + return file_vault_billing_billing_proto_rawDescGZIP(), []int{0} +} + +func (x *PluginBillingDataRequest) GetPluginData() map[string]*anypb.Any { + if x != nil { + return x.PluginData + } + return nil +} + +var File_vault_billing_billing_proto protoreflect.FileDescriptor + +var file_vault_billing_billing_proto_rawDesc = string([]byte{ + 0x0a, 0x1b, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x62, 0x69, 0x6c, 0x6c, 0x69, 0x6e, 0x67, 0x2f, + 0x62, 0x69, 0x6c, 0x6c, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x62, + 0x69, 0x6c, 0x6c, 0x69, 0x6e, 0x67, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x22, 0xc3, 0x01, 0x0a, 0x18, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x42, 0x69, 0x6c, 0x6c, + 0x69, 0x6e, 0x67, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x52, + 0x0a, 0x0b, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x31, 0x2e, 0x62, 0x69, 0x6c, 0x6c, 0x69, 0x6e, 0x67, 0x2e, 0x50, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x42, 0x69, 0x6c, 0x6c, 0x69, 0x6e, 0x67, 0x44, 0x61, 0x74, 0x61, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x44, 0x61, 0x74, + 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x44, 0x61, + 0x74, 0x61, 0x1a, 0x53, 0x0a, 0x0f, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x44, 0x61, 0x74, 0x61, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2a, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2a, 0x5a, 0x28, 0x67, 0x69, 0x74, 0x68, 0x75, + 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, + 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x62, 0x69, 0x6c, 0x6c, + 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_vault_billing_billing_proto_rawDescOnce sync.Once + file_vault_billing_billing_proto_rawDescData []byte +) + +func file_vault_billing_billing_proto_rawDescGZIP() []byte { + file_vault_billing_billing_proto_rawDescOnce.Do(func() { + file_vault_billing_billing_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_vault_billing_billing_proto_rawDesc), len(file_vault_billing_billing_proto_rawDesc))) + }) + return file_vault_billing_billing_proto_rawDescData +} + +var file_vault_billing_billing_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_vault_billing_billing_proto_goTypes = []any{ + (*PluginBillingDataRequest)(nil), // 0: billing.PluginBillingDataRequest + nil, // 1: billing.PluginBillingDataRequest.PluginDataEntry + (*anypb.Any)(nil), // 2: google.protobuf.Any +} +var file_vault_billing_billing_proto_depIdxs = []int32{ + 1, // 0: billing.PluginBillingDataRequest.plugin_data:type_name -> billing.PluginBillingDataRequest.PluginDataEntry + 2, // 1: billing.PluginBillingDataRequest.PluginDataEntry.value:type_name -> google.protobuf.Any + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_vault_billing_billing_proto_init() } +func file_vault_billing_billing_proto_init() { + if File_vault_billing_billing_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_vault_billing_billing_proto_rawDesc), len(file_vault_billing_billing_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_vault_billing_billing_proto_goTypes, + DependencyIndexes: file_vault_billing_billing_proto_depIdxs, + MessageInfos: file_vault_billing_billing_proto_msgTypes, + }.Build() + File_vault_billing_billing_proto = out.File + file_vault_billing_billing_proto_goTypes = nil + file_vault_billing_billing_proto_depIdxs = nil +} diff --git a/vault/billing/billing.proto b/vault/billing/billing.proto new file mode 100644 index 0000000000..ca9b88074e --- /dev/null +++ b/vault/billing/billing.proto @@ -0,0 +1,17 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +syntax = "proto3"; + +package billing; + +import "google/protobuf/any.proto"; + +option go_package = "github.com/hashicorp/vault/vault/billing"; + +// PluginBillingDataRequest contains the in-memory data protection call counts +// from a performance standby node to be sent to the active node +message PluginBillingDataRequest { + // Map of plugin type to count (e.g., "transit" -> count) + map plugin_data = 1; +} diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index e8d008424a..19e61404df 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -46,6 +46,8 @@ type ConsumptionBilling struct { type BillingConfig struct { // For testing purposes. The cadence at which billing metrics are updated MetricsUpdateCadence time.Duration + // For testing purposes. The cadence at which plugin counts are sent from perf standby to active + PluginCountsSendCadence time.Duration } func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric string) string { @@ -65,6 +67,10 @@ type DataProtectionCallCounts struct { var _ logical.ConsumptionBillingManager = (*ConsumptionBilling)(nil) func (s *ConsumptionBilling) WriteBillingData(ctx context.Context, mountType string, data map[string]interface{}) error { + if s == nil { + return nil + } + switch mountType { case "transit": val, ok := data["count"].(uint64) diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index 97df72b9e3..cb25882b8c 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -30,6 +30,13 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error { c.consumptionBillingLock.Unlock() c.postUnsealFuncs = append(c.postUnsealFuncs, func() { c.consumptionBillingMetricsWorker(ctx) + // Start the perf standby plugin counts worker if this is a perf standby + // Access perfStandby field directly to avoid deadlock during post-unseal + if c.perfStandby { + go c.perfStandbyPluginCountsWorker(ctx) + } + // Active nodes don't need a separate worker - they flush counts via + // the existing consumptionBillingMetricsWorker -> updateBillingMetrics path }) return nil diff --git a/vault/consumption_billing_util_oss.go b/vault/consumption_billing_util_oss.go new file mode 100644 index 0000000000..f5eecd133b --- /dev/null +++ b/vault/consumption_billing_util_oss.go @@ -0,0 +1,20 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +//go:build !enterprise + +package vault + +import ( + "context" +) + +// sendPluginCounts is a no-op on OSS +func (c *Core) sendPluginCounts(ctx context.Context) error { + return nil +} + +// perfStandbyPluginCountsWorker is a no-op on OSS +func (c *Core) perfStandbyPluginCountsWorker(ctx context.Context) { + // No-op: performance standby plugin counts worker is enterprise-only +} diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index 84de05bea6..0517fb6d30 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -87,9 +87,13 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic // Data protection call counts are stored to local path only // Each cluster tracks its own total requests to avoid double counting - localDataProtectionCallCounts, err := b.Core.UpdateTransitCallCounts(ctx, currentMonth) + localTransitCallCounts, err := b.Core.UpdateTransitCallCounts(ctx, currentMonth) if err != nil { - return nil, fmt.Errorf("error retrieving local max data protection call counts: %w", err) + return nil, fmt.Errorf("error retrieving local transit call counts: %w", err) + } + localTransformCallCounts, err := b.Core.UpdateTransformCallCounts(ctx, currentMonth) + if err != nil { + return nil, fmt.Errorf("error retrieving local transform call counts: %w", err) } // If we are the primary, then combine the replicated and local max role counts. Else just output the local @@ -97,7 +101,10 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic combinedMaxRoleCounts := combineRoleCounts(replicatedMaxRoleCounts, localMaxRoleCounts) combinedMaxKvCounts := replicatedKvHWMCounts + localKvHWMCounts // Data protection counts are not combined - each cluster reports its own total - combinedMaxDataProtectionCallCounts := localDataProtectionCallCounts + combinedMaxDataProtectionCallCounts := map[string]interface{}{ + "transit": localTransitCallCounts, + "transform": localTransformCallCounts, + } var replicatedPreviousMonthRoleCounts *RoleCounts replicatedPreviousMonthKvHWMCounts := 0 @@ -123,13 +130,20 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic // Data protection counts for previous month localPreviousMonthTransitCallCounts, err := b.Core.GetStoredTransitCallCounts(ctx, previousMonth) if err != nil { - return nil, fmt.Errorf("error retrieving local max data protection call counts for previous month: %w", err) + return nil, fmt.Errorf("error retrieving local transit call counts for previous month: %w", err) + } + localPreviousMonthTransformCallCounts, err := b.Core.GetStoredTransformCallCounts(ctx, previousMonth) + if err != nil { + return nil, fmt.Errorf("error retrieving local transform call counts for previous month: %w", err) } combinedPreviousMonthRoleCounts := combineRoleCounts(replicatedPreviousMonthRoleCounts, localPreviousMonthRoleCounts) combinedPreviousMonthKvHWMCounts := replicatedPreviousMonthKvHWMCounts + localPreviousMonthKvHWMCounts // Data protection counts are not combined - each cluster reports its own total - combinedPreviousMonthTransitCallCounts := localPreviousMonthTransitCallCounts + combinedPreviousMonthDataProtectionCallCounts := map[string]interface{}{ + "transit": localPreviousMonthTransitCallCounts, + "transform": localPreviousMonthTransformCallCounts, + } resp := map[string]interface{}{ "current_month": map[string]interface{}{ @@ -139,11 +153,10 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic "data_protection_call_counts": combinedMaxDataProtectionCallCounts, }, "previous_month": map[string]interface{}{ - "timestamp": previousMonth, - "maximum_role_counts": combinedPreviousMonthRoleCounts, - "maximum_kv_counts": combinedPreviousMonthKvHWMCounts, - // TODO: Add transform data protection call counts - "data_protection_call_counts": combinedPreviousMonthTransitCallCounts, + "timestamp": previousMonth, + "maximum_role_counts": combinedPreviousMonthRoleCounts, + "maximum_kv_counts": combinedPreviousMonthKvHWMCounts, + "data_protection_call_counts": combinedPreviousMonthDataProtectionCallCounts, }, }