Backport VAULT-41675: Transit observations, key management into ce/main (#12380)

* VAULT-41675: Transit observations, key management (#12100)

* start transit implementation

* all observations and tests

* add comments

* cleanup

* Fix broken build (#12384)

---------

Co-authored-by: miagilepner <mia.epner@hashicorp.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
Vault Automation 2026-02-18 09:19:18 -05:00 committed by GitHub
parent ccceb19d02
commit 67fb5f3eda
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 389 additions and 58 deletions

View file

@ -32,6 +32,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/helper/testhelpers/observations"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/billing"
@ -135,6 +136,27 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical
return b
}
func createBackendWithObservationRecorder(t testing.TB) (*backend, logical.Storage, *observations.TestObservationRecorder) {
config := logical.TestBackendConfig()
obsRecorder := observations.NewTestObservationRecorder()
config.StorageView = &logical.InmemStorage{}
config.ObservationRecorder = obsRecorder
b, _ := Backend(context.Background(), config)
require.NotNil(t, b)
err := b.Backend.Setup(context.Background(), config)
require.NoError(t, err)
return b, config.StorageView, obsRecorder
}
func factoryWithObservationRecorder(t testing.TB) (logical.Factory, *observations.TestObservationRecorder) {
obsRecorder := observations.NewTestObservationRecorder()
return func(ctx context.Context, bc *logical.BackendConfig) (logical.Backend, error) {
bc.ObservationRecorder = obsRecorder
return Factory(ctx, bc)
}, obsRecorder
}
func TestTransit_RSA(t *testing.T) {
testTransit_RSA(t, "rsa-2048")
testTransit_RSA(t, "rsa-3072")
@ -361,56 +383,59 @@ func testTransit_RSA(t *testing.T, keyType string) {
}
func TestBackend_basic(t *testing.T) {
factory, obsRecorder := factoryWithObservationRecorder(t)
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
LogicalFactory: Factory,
LogicalFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepWritePolicy(t, "test", false, obsRecorder),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepReadPolicy(t, "test", false, false, obsRecorder),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepEncrypt(t, "test", "", decryptData),
testAccStepDecrypt(t, "test", "", decryptData),
testAccStepDeleteNotDisabledPolicy(t, "test"),
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepWritePolicy(t, "test", false),
testAccStepDeletePolicy(t, "test", obsRecorder),
testAccStepWritePolicy(t, "test", false, obsRecorder),
testAccStepEnableDeletion(t, "test"),
testAccStepDisableDeletion(t, "test"),
testAccStepDeleteNotDisabledPolicy(t, "test"),
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, false),
testAccStepDeletePolicy(t, "test", obsRecorder),
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
},
})
}
func TestBackend_upsert(t *testing.T) {
factory, obsRecorder := factoryWithObservationRecorder(t)
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
LogicalFactory: Factory,
LogicalFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepReadPolicy(t, "test", true, false),
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
testAccStepListPolicy(t, "test", true),
testAccStepEncryptUpsert(t, "test", testPlaintext, decryptData),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepReadPolicy(t, "test", false, false, obsRecorder),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
},
})
}
func TestBackend_datakey(t *testing.T) {
factory, obsRecorder := factoryWithObservationRecorder(t)
dataKeyInfo := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
LogicalFactory: Factory,
LogicalFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepWritePolicy(t, "test", false, obsRecorder),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepReadPolicy(t, "test", false, false, nil),
testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo),
testAccStepDecryptDatakey(t, "test", dataKeyInfo),
testAccStepWriteDatakey(t, "test", true, 128, dataKeyInfo),
@ -428,11 +453,12 @@ func TestBackend_rotation(t *testing.T) {
func testBackendRotation(t *testing.T) {
decryptData := make(map[string]interface{})
encryptHistory := make(map[int]map[string]interface{})
factory, obsRecorder := factoryWithObservationRecorder(t)
logicaltest.Test(t, logicaltest.TestCase{
LogicalFactory: Factory,
LogicalFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepWritePolicy(t, "test", false, obsRecorder),
testAccStepListPolicy(t, "test", false),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory),
@ -460,7 +486,7 @@ func testBackendRotation(t *testing.T) {
testAccStepDeleteNotDisabledPolicy(t, "test"),
testAccStepAdjustPolicyMinDecryption(t, "test", 3),
testAccStepAdjustPolicyMinEncryption(t, "test", 4),
testAccStepReadPolicyWithVersions(t, "test", false, false, 3, 4),
testAccStepReadPolicyWithVersions(t, "test", false, false, 3, 4, obsRecorder),
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
@ -472,7 +498,7 @@ func testBackendRotation(t *testing.T) {
testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepAdjustPolicyMinDecryption(t, "test", 1),
testAccStepReadPolicyWithVersions(t, "test", false, false, 1, 4),
testAccStepReadPolicyWithVersions(t, "test", false, false, 1, 4, obsRecorder),
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
@ -482,8 +508,8 @@ func testBackendRotation(t *testing.T) {
testAccStepRewrap(t, "test", decryptData, 4),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, false),
testAccStepDeletePolicy(t, "test", obsRecorder),
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
testAccStepListPolicy(t, "test", true),
},
})
@ -491,29 +517,43 @@ func testBackendRotation(t *testing.T) {
func TestBackend_basic_derived(t *testing.T) {
decryptData := make(map[string]interface{})
factory, obsRecorder := factoryWithObservationRecorder(t)
logicaltest.Test(t, logicaltest.TestCase{
LogicalFactory: Factory,
LogicalFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", true),
testAccStepWritePolicy(t, "test", true, obsRecorder),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, true),
testAccStepReadPolicy(t, "test", false, true, obsRecorder),
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, true),
testAccStepDeletePolicy(t, "test", obsRecorder),
testAccStepReadPolicy(t, "test", true, true, obsRecorder),
},
})
}
func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest.TestStep {
func testAccStepWritePolicy(t *testing.T, name string, derived bool, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
ts := logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "keys/" + name,
Data: map[string]interface{}{
"derived": derived,
},
Check: func(resp *logical.Response) error {
if obsRecorder == nil {
return nil
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyWrite)
if obs == nil {
return fmt.Errorf("no observation")
}
if name != obs.Data["key_name"] {
return fmt.Errorf("expected name %s, got %s", name, obs.Data["key_name"])
}
return nil
},
}
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
ts.Data["type"] = "chacha20-poly1305"
@ -597,10 +637,24 @@ func testAccStepEnableDeletion(t *testing.T, name string) logicaltest.TestStep {
}
}
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
func testAccStepDeletePolicy(t *testing.T, name string, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "keys/" + name,
Check: func(_ *logical.Response) error {
if obsRecorder == nil {
return nil
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyDelete)
if obs == nil {
return fmt.Errorf("expected observation of type %s but got none", ObservationTypeTransitKeyDelete)
}
if obs.Data["key_name"] != name {
return fmt.Errorf("expected name %s, got %s", name, obs.Data["key_name"])
}
return nil
},
}
}
@ -621,11 +675,11 @@ func testAccStepDeleteNotDisabledPolicy(t *testing.T, name string) logicaltest.T
}
}
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
return testAccStepReadPolicyWithVersions(t, name, expectNone, derived, 1, 0)
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
return testAccStepReadPolicyWithVersions(t, name, expectNone, derived, 1, 0, obsRecorder)
}
func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, derived bool, minDecryptionVersion int, minEncryptionVersion int) logicaltest.TestStep {
func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, derived bool, minDecryptionVersion int, minEncryptionVersion int, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "keys/" + name,
@ -686,6 +740,25 @@ func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, de
if derived && d.KDF != "hkdf_sha256" {
return fmt.Errorf("bad: %#v", d)
}
if obsRecorder == nil {
return nil
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRead)
if obs == nil {
return fmt.Errorf("expected key read observation but found none")
}
if obs.Data == nil {
return fmt.Errorf("observation data should not be nil")
}
keyName, ok := obs.Data["key_name"]
if !ok {
return fmt.Errorf("observation data missing key_name field")
}
if keyName != name {
return fmt.Errorf("observation key_name mismatch: expected %s, got %v", name, keyName)
}
return nil
},
}

View file

@ -0,0 +1,81 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: BUSL-1.1
package transit
const (
// key type observations
// ObservationTypeTransitKeyRotateSuccess is emitted when a key is successfully rotated.
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
ObservationTypeTransitKeyRotateSuccess = "transit/key/rotate/success"
// ObservationTypeTransitKeyRotateFail is emitted when a key rotation fails.
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
ObservationTypeTransitKeyRotateFail = "transit/key/rotate/success"
// ObservationTypeTransitKeyWrite is emitted when a new key is created.
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
ObservationTypeTransitKeyWrite = "transit/key/write"
// ObservationTypeTransitKeyRead is emitted when a key is read.
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived)
ObservationTypeTransitKeyRead = "transit/key/read"
// ObservationTypeTransitKeyDelete is emitted when a key is deleted.
// Metadata: key_name
ObservationTypeTransitKeyDelete = "transit/key/delete"
// ObservationTypeTransitKeyImport is emitted when a key is imported.
// For new key imports, metadata includes: key_name, type, derived, exportable,
// allow_plaintext_backup, auto_rotate_period
// For version imports, metadata includes: key_name, type, derived, deletion_allowed,
// min_available_version, min_decryption_version, min_encryption_version, latest_version,
// exportable, allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived), import_version
ObservationTypeTransitKeyImport = "transit/key/import"
// ObservationTypeTransitKeyExport is emitted when a key is exported.
// For full key exports, metadata includes: key_name, type, derived, deletion_allowed,
// min_available_version, min_decryption_version, min_encryption_version, latest_version,
// exportable, allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived)
// For single version exports, metadata also includes: export_version
ObservationTypeTransitKeyExport = "transit/key/export"
// ObservationTypeTransitKeyExportBYOK is emitted when a key is exported using BYOK (Bring Your Own Key).
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived), export_version (if specified),
// destination_key
ObservationTypeTransitKeyExportBYOK = "transit/key/export/byok"
// ObservationTypeTransitKeyBackup is emitted when a key is backed up.
// Metadata: key_name
ObservationTypeTransitKeyBackup = "transit/key/backup"
// ObservationTypeTransitKeyRestore is emitted when a key is restored from backup.
// Metadata: key_name, force
ObservationTypeTransitKeyRestore = "transit/key/restore"
// ObservationTypeTransitKeyTrim is emitted when old key versions are trimmed.
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
// min_decryption_version, min_encryption_version, latest_version, exportable,
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
// kdf_mode (if derived), convergent_encryption (if derived)
ObservationTypeTransitKeyTrim = "transit/key/trim"
)

View file

@ -37,11 +37,15 @@ func (b *backend) pathBackup() *framework.Path {
}
func (b *backend) pathBackupRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
backup, err := b.lm.BackupPolicy(ctx, req.Storage, d.Get("name").(string))
name := d.Get("name").(string)
backup, err := b.lm.BackupPolicy(ctx, req.Storage, name)
if err != nil {
return nil, err
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyBackup, map[string]interface{}{
"key_name": name,
})
return &logical.Response{
Data: map[string]interface{}{
"backup": backup,

View file

@ -8,6 +8,7 @@ import (
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestTransit_BackupRestore(t *testing.T) {
@ -46,7 +47,7 @@ func testBackupRestore(t *testing.T, keyType, feature string) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
// Create a key
keyReq := &logical.Request{
@ -255,4 +256,13 @@ func testBackupRestore(t *testing.T, keyType, feature string) {
// Ensure that the restored key is functional
validationFunc("test1")
backupObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyBackup)
require.Len(t, backupObservations, 1)
require.Equal(t, "test", backupObservations[0].Data["key_name"])
restoreObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyRestore)
require.Len(t, restoreObservations, 2)
require.Equal(t, "test", restoreObservations[0].Data["key_name"])
require.Equal(t, "test1", restoreObservations[1].Data["key_name"])
}

View file

@ -103,6 +103,7 @@ func (b *backend) pathPolicyBYOKExportRead(ctx context.Context, req *logical.Req
}
retKeys := map[string]string{}
var exportVersion *int
switch version {
case "":
for k, v := range srcP.Keys {
@ -139,8 +140,16 @@ func (b *backend) pathPolicyBYOKExportRead(ctx context.Context, req *logical.Req
}
retKeys[strconv.Itoa(versionValue)] = exportKey
exportVersion = &versionValue
}
metadata := b.keyPolicyObservationMetadata(srcP)
if exportVersion != nil {
metadata["export_version"] = *exportVersion
}
metadata["destination_key"] = dstP.Name
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExportBYOK, metadata)
resp := &logical.Response{
Data: map[string]interface{}{
"name": srcP.Name,

View file

@ -8,6 +8,7 @@ import (
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestTransit_BYOKExportImport(t *testing.T) {
@ -36,7 +37,7 @@ func testBYOKExportImport(t *testing.T, keyType, feature string) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
// Create a key
keyReq := &logical.Request{
@ -92,6 +93,11 @@ func testBYOKExportImport(t *testing.T, keyType, feature string) {
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("resp: %#v\nerr: %v", resp, err)
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyExportBYOK)
require.NotNil(t, obs)
require.Equal(t, "test-source", obs.Data["key_name"])
require.Equal(t, "wrapper", obs.Data["destination_key"])
keys := resp.Data["keys"].(map[string]string)
// Import the key to a new name.

View file

@ -129,6 +129,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
retKeys[k] = exportKey
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExport, b.keyPolicyObservationMetadata(p))
default:
var versionValue int
if version == "latest" {
@ -155,6 +156,9 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
}
retKeys[strconv.Itoa(versionValue)] = exportKey
metadata := b.keyPolicyObservationMetadata(p)
metadata["export_version"] = versionValue
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExport, metadata)
}
resp := &logical.Response{

View file

@ -94,7 +94,7 @@ func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) {
func verifyExportsCorrectVersion(t *testing.T, exportType, keyType, parameterSet, ecKeyType string) {
t.Run(keyType+":"+ecKeyType, func(t *testing.T) {
b, storage := createBackendWithSysView(t)
b, storage, obsRecorder := createBackendWithObservationRecorder(t)
// First create a key, v1
req := &logical.Request{
@ -161,6 +161,10 @@ func verifyExportsCorrectVersion(t *testing.T, exportType, keyType, parameterSet
t.Fatalf("expected version %q, received version %q", strconv.Itoa(expectedVersion), k)
}
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyExport)
require.NotNil(t, obs)
require.Equal(t, obs.Data["key_name"], "foo")
require.Equal(t, obs.Data["export_version"], expectedVersion)
}
verifyVersion("v1", 1)

View file

@ -256,6 +256,15 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *
return nil, err
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyImport, map[string]interface{}{
"key_name": name,
"type": polReq.KeyType,
"derived": polReq.Derived,
"exportable": polReq.Exportable,
"allow_plaintext_backup": polReq.AllowPlaintextBackup,
"auto_rotate_period": int64(autoRotatePeriod.Seconds()),
})
return nil, nil
}
@ -297,14 +306,16 @@ func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Reque
return resp, err
}
var versionToUpdate *int
// Get param version if set else import a new version.
if version, ok := d.GetOk("version"); ok {
versionToUpdate := version.(int)
versionValue := version.(int)
versionToUpdate = &versionValue
// Check if given version can be updated given input
err = p.KeyVersionCanBeUpdated(versionToUpdate, isCiphertextSet)
err = p.KeyVersionCanBeUpdated(*versionToUpdate, isCiphertextSet)
if err == nil {
err = p.ImportPrivateKeyForVersion(ctx, req.Storage, versionToUpdate, key)
err = p.ImportPrivateKeyForVersion(ctx, req.Storage, *versionToUpdate, key)
}
} else {
err = p.ImportPublicOrPrivate(ctx, req.Storage, key, isCiphertextSet, b.GetRandomReader())
@ -314,6 +325,12 @@ func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Reque
return nil, err
}
metadata := b.keyPolicyObservationMetadata(p)
if versionToUpdate != nil {
metadata["import_version"] = *versionToUpdate
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyImport, metadata)
return nil, nil
}

View file

@ -22,6 +22,7 @@ import (
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/cryptoutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
"github.com/tink-crypto/tink-go/v2/kwp/subtle"
)
@ -92,7 +93,7 @@ func getKey(t *testing.T, keyType string) interface{} {
func TestTransit_ImportNSSEd25519Key(t *testing.T) {
generateKeys(t)
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
wrappingKey, err := b.getWrappingKey(context.Background(), s)
if err != nil || wrappingKey == nil {
@ -121,11 +122,16 @@ func TestTransit_ImportNSSEd25519Key(t *testing.T) {
if err != nil {
t.Fatalf("failed to import NSS-formatted Ed25519 key: %v", err)
}
// Verify observation was recorded
importObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyImport)
require.Len(t, importObservations, 1)
require.Equal(t, "nss-ed25519", importObservations[0].Data["key_name"])
}
func TestTransit_ImportRSAPSS(t *testing.T) {
generateKeys(t)
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
wrappingKey, err := b.getWrappingKey(context.Background(), s)
if err != nil || wrappingKey == nil {
@ -154,12 +160,21 @@ func TestTransit_ImportRSAPSS(t *testing.T) {
if err != nil {
t.Fatalf("failed to import RSA-PSS private key: %v", err)
}
importObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyImport)
require.Len(t, importObservations, 1)
require.Equal(t, "rsa-pss", importObservations[0].Data["key_name"])
}
func TestTransit_Import(t *testing.T) {
generateKeys(t)
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
checkImportObservation := func(t *testing.T, keyName string) {
t.Helper()
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
require.NotNil(t, obs)
require.Equal(t, keyName, obs.Data["key_name"])
}
t.Run(
"import into a key fails before wrapping key is read",
func(t *testing.T) {
@ -259,6 +274,7 @@ func TestTransit_Import(t *testing.T) {
if err != nil {
t.Fatalf("failed to import valid key: %s", err)
}
checkImportObservation(t, keyID)
},
)
@ -380,6 +396,7 @@ func TestTransit_Import(t *testing.T) {
t.Fatalf("failed to import key: %s", err)
}
checkImportObservation(t, keyID)
// Rotate key
req = &logical.Request{
Storage: s,
@ -390,6 +407,10 @@ func TestTransit_Import(t *testing.T) {
if err != nil {
t.Fatalf("failed to rotate key: %s", err)
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRotateSuccess)
require.NotNil(t, obs)
require.Equal(t, obs.Data["key_name"], keyID)
},
)
@ -418,6 +439,8 @@ func TestTransit_Import(t *testing.T) {
t.Fatalf("failed to import key: %s", err)
}
checkImportObservation(t, keyID)
// Rotate key
req = &logical.Request{
Storage: s,
@ -461,6 +484,7 @@ func TestTransit_Import(t *testing.T) {
if err != nil {
t.Fatalf("failed to import ed25519 key: %v", err)
}
checkImportObservation(t, keyID)
})
t.Run(
@ -493,12 +517,13 @@ func TestTransit_Import(t *testing.T) {
if err != nil {
t.Fatalf("failed to import public key: %s", err)
}
checkImportObservation(t, keyID)
})
}
func TestTransit_ImportVersion(t *testing.T) {
generateKeys(t)
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
t.Run(
"import into a key version fails before wrapping key is read",
@ -686,13 +711,19 @@ func TestTransit_ImportVersion(t *testing.T) {
if err != nil {
t.Fatalf("failed to update key: %s", err)
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
require.NotNil(t, obs)
require.Equal(t, keyID, obs.Data["key_name"])
require.Equal(t, keyType, obs.Data["type"])
require.NotContains(t, obs.Data, "import_version")
},
)
}
func TestTransit_ImportVersionWithPublicKeys(t *testing.T) {
generateKeys(t)
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
// Retrieve public wrapping key
wrappingKey, err := b.getWrappingKey(context.Background(), s)
@ -931,6 +962,11 @@ func TestTransit_ImportVersionWithPublicKeys(t *testing.T) {
t.Fatalf("failed to import private key: %s", err)
}
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
require.NotNil(t, obs)
require.Equal(t, keyID, obs.Data["key_name"])
require.Equal(t, 1, obs.Data["import_version"])
// We should still have two keys on export
req = &logical.Request{
Storage: s,

View file

@ -348,6 +348,12 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
if !upserted {
resp.AddWarning(fmt.Sprintf("key %s already existed", name))
}
metadata := b.keyPolicyObservationMetadata(p)
if polReq.ManagedKeyUUID != "" {
metadata["managed_key_id"] = polReq.ManagedKeyUUID
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyWrite, metadata)
return resp, nil
}
@ -387,9 +393,51 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f
}
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRead, b.keyPolicyObservationMetadata(p))
return b.formatKeyPolicy(p, context)
}
func (b *backend) keyPolicyObservationMetadata(p *keysutil.Policy) map[string]interface{} {
metadata := map[string]interface{}{
"key_name": p.Name,
"type": p.Type.String(),
"derived": p.Derived,
"deletion_allowed": p.DeletionAllowed,
"min_available_version": p.MinAvailableVersion,
"min_decryption_version": p.MinDecryptionVersion,
"min_encryption_version": p.MinEncryptionVersion,
"latest_version": p.LatestVersion,
"exportable": p.Exportable,
"allow_plaintext_backup": p.AllowPlaintextBackup,
"auto_rotate_period": int64(p.AutoRotatePeriod.Seconds()),
"imported_key": p.Imported,
}
if p.Derived {
switch p.KDF {
case keysutil.Kdf_hmac_sha256_counter:
metadata["kdf"] = "hmac-sha256-counter"
metadata["kdf_mode"] = "hmac-sha256-counter"
case keysutil.Kdf_hkdf_sha256:
metadata["kdf"] = "hkdf_sha256"
}
metadata["convergent_encryption"] = p.ConvergentEncryption
if p.ConvergentEncryption {
metadata["convergent_encryption_version"] = p.ConvergentVersion
}
}
if p.ParameterSet != "" {
metadata["parameter_set"] = p.ParameterSet
}
if p.Type == keysutil.KeyType_HYBRID {
metadata["hybrid_key_type_pqc"] = p.HybridConfig.PQCKeyType.String()
metadata["hybrid_key_type_ec"] = p.HybridConfig.ECKeyType.String()
}
return metadata
}
func (b *backend) formatKeyPolicy(p *keysutil.Policy, context []byte) (*logical.Response, error) {
// Return the response
resp := &logical.Response{
@ -548,6 +596,9 @@ func (b *backend) pathPolicyDelete(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyDelete, map[string]interface{}{
"key_name": name,
})
return nil, nil
}

View file

@ -61,7 +61,16 @@ func (b *backend) pathRestoreUpdate(ctx context.Context, req *logical.Request, d
return nil, ErrInvalidKeyName
}
return nil, b.lm.RestorePolicy(ctx, req.Storage, keyName, backupB64, force)
restoredKeyName, err := b.lm.RestorePolicy(ctx, req.Storage, keyName, backupB64, force)
if err != nil {
return nil, err
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRestore, map[string]interface{}{
"key_name": restoredKeyName,
"force": force,
})
// ignore-nil-nil-function-check
return nil, nil
}
const (

View file

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/vault/helper/testhelpers"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestTransit_Restore(t *testing.T) {
@ -25,7 +26,7 @@ func TestTransit_Restore(t *testing.T) {
// as if the key already existed
keyType := "aes256-gcm96"
b, s := createBackendWithStorage(t)
b, s, obsRecorder := createBackendWithObservationRecorder(t)
keyName := testhelpers.RandomWithPrefix("my-key")
// Create a key
@ -227,6 +228,16 @@ func TestTransit_Restore(t *testing.T) {
readKeyName = tc.RestoreName
}
if tc.ExpectedErr == nil {
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRestore)
require.NotNil(t, obs)
require.Equal(t, obs.Data["key_name"], readKeyName)
force := false
if tc.Force != nil {
force = *tc.Force
}
require.Equal(t, obs.Data["force"], force)
}
// read the key and make sure it's there
readReq := &logical.Request{
Path: "keys/" + readKeyName,

View file

@ -68,9 +68,8 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
p.Lock(true)
}
defer p.Unlock()
var keyId string
if p.Type == keysutil.KeyType_MANAGED_KEY {
var keyId string
keyId, err = GetManagedKeyUUID(ctx, b, managedKeyName, managedKeyId)
if err != nil {
b.Logger().Error("failed to rotate key", "name", name, "error", err.Error())
@ -82,17 +81,26 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
err = p.Rotate(ctx, req.Storage, b.GetRandomReader())
}
keyMetadata := b.keyPolicyObservationMetadata(p)
if p.Type == keysutil.KeyType_MANAGED_KEY && keyId != "" {
keyMetadata["managed_key_id"] = keyId
}
if err != nil {
b.Logger().Error("failed to rotate key on user request", "name", name, "error", err.Error())
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateFail, keyMetadata)
return nil, err
}
resp, err := b.formatKeyPolicy(p, nil)
if err != nil {
b.Logger().Error("failed to rotate key on user request", "name", name, "error", err.Error())
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateFail, keyMetadata)
} else {
b.Logger().Info("successfully rotated key on user request", "name", name)
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateSuccess, keyMetadata)
}
// formatKeyPolicy returns a response even on error so be sure to return both.
return resp, err
}

View file

@ -100,6 +100,8 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
return nil, err
}
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyTrim, b.keyPolicyObservationMetadata(p))
return b.formatKeyPolicy(p, nil)
}
}

View file

@ -9,10 +9,11 @@ import (
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestTransit_Trim(t *testing.T) {
b, storage := createBackendWithSysView(t)
b, storage, obsRecorder := createBackendWithObservationRecorder(t)
doReq := func(t *testing.T, req *logical.Request) *logical.Response {
t.Helper()
@ -270,4 +271,9 @@ func TestTransit_Trim(t *testing.T) {
if len(archive.Keys) != 4 {
t.Fatalf("bad: len of archived keys; expected: 4, actual: %d", len(archive.Keys))
}
trimObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyTrim)
require.Len(t, trimObservations, 2)
require.Equal(t, trimObservations[0].Data["key_name"], "aes")
require.Equal(t, trimObservations[1].Data["key_name"], "aes")
}

View file

@ -150,21 +150,21 @@ func (lm *LockManager) InitCache(cacheSize int) error {
// RestorePolicy acquires an exclusive lock on the policy name and restores the
// given policy along with the archive.
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) (string, error) {
backupBytes, err := base64.StdEncoding.DecodeString(backup)
if err != nil {
return err
return "", err
}
var keyData KeyData
err = jsonutil.DecodeJSON(backupBytes, &keyData)
if err != nil {
return err
return "", err
}
// Validate that the policy exists in the backup data
if keyData.Policy == nil {
return errors.New("backup data does not contain a valid policy")
return "", errors.New("backup data does not contain a valid policy")
}
// Set a different name if desired
@ -188,7 +188,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
if lm.useCache {
pRaw, ok = lm.cache.Load(name)
if ok && !force {
return fmt.Errorf("key %q already exists", name)
return "", fmt.Errorf("key %q already exists", name)
}
}
@ -207,10 +207,10 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
if pRaw == nil {
p, err = lm.getPolicyFromStorage(ctx, storage, name)
if err != nil {
return err
return "", err
}
if p != nil && !force {
return fmt.Errorf("key %q already exists", name)
return "", fmt.Errorf("key %q already exists", name)
}
}
@ -230,7 +230,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
if keyData.ArchivedKeys != nil {
err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
return "", errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
}
}
@ -243,7 +243,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
// Restore the policy. This will also attempt to adjust the archive.
err = keyData.Policy.Persist(ctx, storage)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
return "", errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
}
keyData.Policy.l = new(sync.RWMutex)
@ -252,7 +252,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
if lm.useCache {
lm.cache.Store(name, keyData.Policy)
}
return nil
return name, nil
}
func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {

View file

@ -110,7 +110,7 @@ func TestRestorePolicy_NilPolicy(t *testing.T) {
// Create backup data without "policy" field (causes nil Policy)
invalidBackup := base64.StdEncoding.EncodeToString([]byte(`{"archived_keys": null}`))
err = lm.RestorePolicy(ctx, storage, "test-key", invalidBackup, false)
_, err = lm.RestorePolicy(ctx, storage, "test-key", invalidBackup, false)
require.Error(t, err)
require.Contains(t, err.Error(), "backup data does not contain a valid policy")
}