mirror of
https://github.com/hashicorp/vault.git
synced 2026-02-18 10:28:13 -05:00
Backport VAULT-41675: Transit observations, key management into ce/main (#12380)
* VAULT-41675: Transit observations, key management (#12100) * start transit implementation * all observations and tests * add comments * cleanup * Fix broken build (#12384) --------- Co-authored-by: miagilepner <mia.epner@hashicorp.com> Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
parent
ccceb19d02
commit
67fb5f3eda
18 changed files with 389 additions and 58 deletions
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/testhelpers/observations"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
|
|
@ -135,6 +136,27 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical
|
|||
return b
|
||||
}
|
||||
|
||||
func createBackendWithObservationRecorder(t testing.TB) (*backend, logical.Storage, *observations.TestObservationRecorder) {
|
||||
config := logical.TestBackendConfig()
|
||||
obsRecorder := observations.NewTestObservationRecorder()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.ObservationRecorder = obsRecorder
|
||||
|
||||
b, _ := Backend(context.Background(), config)
|
||||
require.NotNil(t, b)
|
||||
err := b.Backend.Setup(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
return b, config.StorageView, obsRecorder
|
||||
}
|
||||
|
||||
func factoryWithObservationRecorder(t testing.TB) (logical.Factory, *observations.TestObservationRecorder) {
|
||||
obsRecorder := observations.NewTestObservationRecorder()
|
||||
return func(ctx context.Context, bc *logical.BackendConfig) (logical.Backend, error) {
|
||||
bc.ObservationRecorder = obsRecorder
|
||||
return Factory(ctx, bc)
|
||||
}, obsRecorder
|
||||
}
|
||||
|
||||
func TestTransit_RSA(t *testing.T) {
|
||||
testTransit_RSA(t, "rsa-2048")
|
||||
testTransit_RSA(t, "rsa-3072")
|
||||
|
|
@ -361,56 +383,59 @@ func testTransit_RSA(t *testing.T, keyType string) {
|
|||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
factory, obsRecorder := factoryWithObservationRecorder(t)
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalFactory: Factory,
|
||||
LogicalFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepWritePolicy(t, "test", false, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepReadPolicy(t, "test", false, false, obsRecorder),
|
||||
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEncrypt(t, "test", "", decryptData),
|
||||
testAccStepDecrypt(t, "test", "", decryptData),
|
||||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepDeletePolicy(t, "test", obsRecorder),
|
||||
testAccStepWritePolicy(t, "test", false, obsRecorder),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDisableDeletion(t, "test"),
|
||||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
testAccStepDeletePolicy(t, "test", obsRecorder),
|
||||
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_upsert(t *testing.T) {
|
||||
factory, obsRecorder := factoryWithObservationRecorder(t)
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalFactory: Factory,
|
||||
LogicalFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
testAccStepEncryptUpsert(t, "test", testPlaintext, decryptData),
|
||||
testAccStepListPolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepReadPolicy(t, "test", false, false, obsRecorder),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_datakey(t *testing.T) {
|
||||
factory, obsRecorder := factoryWithObservationRecorder(t)
|
||||
dataKeyInfo := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalFactory: Factory,
|
||||
LogicalFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepWritePolicy(t, "test", false, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepReadPolicy(t, "test", false, false, nil),
|
||||
testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo),
|
||||
testAccStepDecryptDatakey(t, "test", dataKeyInfo),
|
||||
testAccStepWriteDatakey(t, "test", true, 128, dataKeyInfo),
|
||||
|
|
@ -428,11 +453,12 @@ func TestBackend_rotation(t *testing.T) {
|
|||
func testBackendRotation(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
encryptHistory := make(map[int]map[string]interface{})
|
||||
factory, obsRecorder := factoryWithObservationRecorder(t)
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalFactory: Factory,
|
||||
LogicalFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepWritePolicy(t, "test", false, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", false),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory),
|
||||
|
|
@ -460,7 +486,7 @@ func testBackendRotation(t *testing.T) {
|
|||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepAdjustPolicyMinDecryption(t, "test", 3),
|
||||
testAccStepAdjustPolicyMinEncryption(t, "test", 4),
|
||||
testAccStepReadPolicyWithVersions(t, "test", false, false, 3, 4),
|
||||
testAccStepReadPolicyWithVersions(t, "test", false, false, 3, 4, obsRecorder),
|
||||
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
|
||||
|
|
@ -472,7 +498,7 @@ func testBackendRotation(t *testing.T) {
|
|||
testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepAdjustPolicyMinDecryption(t, "test", 1),
|
||||
testAccStepReadPolicyWithVersions(t, "test", false, false, 1, 4),
|
||||
testAccStepReadPolicyWithVersions(t, "test", false, false, 1, 4, obsRecorder),
|
||||
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
|
||||
|
|
@ -482,8 +508,8 @@ func testBackendRotation(t *testing.T) {
|
|||
testAccStepRewrap(t, "test", decryptData, 4),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
testAccStepDeletePolicy(t, "test", obsRecorder),
|
||||
testAccStepReadPolicy(t, "test", true, false, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
},
|
||||
})
|
||||
|
|
@ -491,29 +517,43 @@ func testBackendRotation(t *testing.T) {
|
|||
|
||||
func TestBackend_basic_derived(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
factory, obsRecorder := factoryWithObservationRecorder(t)
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalFactory: Factory,
|
||||
LogicalFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepListPolicy(t, "test", true),
|
||||
testAccStepWritePolicy(t, "test", true),
|
||||
testAccStepWritePolicy(t, "test", true, obsRecorder),
|
||||
testAccStepListPolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, true),
|
||||
testAccStepReadPolicy(t, "test", false, true, obsRecorder),
|
||||
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, true),
|
||||
testAccStepDeletePolicy(t, "test", obsRecorder),
|
||||
testAccStepReadPolicy(t, "test", true, true, obsRecorder),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest.TestStep {
|
||||
func testAccStepWritePolicy(t *testing.T, name string, derived bool, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
|
||||
ts := logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "keys/" + name,
|
||||
Data: map[string]interface{}{
|
||||
"derived": derived,
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
if obsRecorder == nil {
|
||||
return nil
|
||||
}
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyWrite)
|
||||
if obs == nil {
|
||||
return fmt.Errorf("no observation")
|
||||
}
|
||||
if name != obs.Data["key_name"] {
|
||||
return fmt.Errorf("expected name %s, got %s", name, obs.Data["key_name"])
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
|
||||
ts.Data["type"] = "chacha20-poly1305"
|
||||
|
|
@ -597,10 +637,24 @@ func testAccStepEnableDeletion(t *testing.T, name string) logicaltest.TestStep {
|
|||
}
|
||||
}
|
||||
|
||||
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
|
||||
func testAccStepDeletePolicy(t *testing.T, name string, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "keys/" + name,
|
||||
Check: func(_ *logical.Response) error {
|
||||
if obsRecorder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyDelete)
|
||||
if obs == nil {
|
||||
return fmt.Errorf("expected observation of type %s but got none", ObservationTypeTransitKeyDelete)
|
||||
}
|
||||
if obs.Data["key_name"] != name {
|
||||
return fmt.Errorf("expected name %s, got %s", name, obs.Data["key_name"])
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -621,11 +675,11 @@ func testAccStepDeleteNotDisabledPolicy(t *testing.T, name string) logicaltest.T
|
|||
}
|
||||
}
|
||||
|
||||
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
|
||||
return testAccStepReadPolicyWithVersions(t, name, expectNone, derived, 1, 0)
|
||||
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
|
||||
return testAccStepReadPolicyWithVersions(t, name, expectNone, derived, 1, 0, obsRecorder)
|
||||
}
|
||||
|
||||
func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, derived bool, minDecryptionVersion int, minEncryptionVersion int) logicaltest.TestStep {
|
||||
func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, derived bool, minDecryptionVersion int, minEncryptionVersion int, obsRecorder *observations.TestObservationRecorder) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "keys/" + name,
|
||||
|
|
@ -686,6 +740,25 @@ func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, de
|
|||
if derived && d.KDF != "hkdf_sha256" {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
|
||||
if obsRecorder == nil {
|
||||
return nil
|
||||
}
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRead)
|
||||
if obs == nil {
|
||||
return fmt.Errorf("expected key read observation but found none")
|
||||
}
|
||||
if obs.Data == nil {
|
||||
return fmt.Errorf("observation data should not be nil")
|
||||
}
|
||||
keyName, ok := obs.Data["key_name"]
|
||||
if !ok {
|
||||
return fmt.Errorf("observation data missing key_name field")
|
||||
}
|
||||
if keyName != name {
|
||||
return fmt.Errorf("observation key_name mismatch: expected %s, got %v", name, keyName)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
|
|
|||
81
builtin/logical/transit/observation_consts.go
Normal file
81
builtin/logical/transit/observation_consts.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright IBM Corp. 2016, 2025
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package transit
|
||||
|
||||
const (
|
||||
|
||||
// key type observations
|
||||
|
||||
// ObservationTypeTransitKeyRotateSuccess is emitted when a key is successfully rotated.
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
|
||||
ObservationTypeTransitKeyRotateSuccess = "transit/key/rotate/success"
|
||||
|
||||
// ObservationTypeTransitKeyRotateFail is emitted when a key rotation fails.
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
|
||||
ObservationTypeTransitKeyRotateFail = "transit/key/rotate/success"
|
||||
|
||||
// ObservationTypeTransitKeyWrite is emitted when a new key is created.
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived), managed_key_id (if managed key)
|
||||
ObservationTypeTransitKeyWrite = "transit/key/write"
|
||||
|
||||
// ObservationTypeTransitKeyRead is emitted when a key is read.
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived)
|
||||
ObservationTypeTransitKeyRead = "transit/key/read"
|
||||
|
||||
// ObservationTypeTransitKeyDelete is emitted when a key is deleted.
|
||||
// Metadata: key_name
|
||||
ObservationTypeTransitKeyDelete = "transit/key/delete"
|
||||
|
||||
// ObservationTypeTransitKeyImport is emitted when a key is imported.
|
||||
// For new key imports, metadata includes: key_name, type, derived, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period
|
||||
// For version imports, metadata includes: key_name, type, derived, deletion_allowed,
|
||||
// min_available_version, min_decryption_version, min_encryption_version, latest_version,
|
||||
// exportable, allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived), import_version
|
||||
ObservationTypeTransitKeyImport = "transit/key/import"
|
||||
|
||||
// ObservationTypeTransitKeyExport is emitted when a key is exported.
|
||||
// For full key exports, metadata includes: key_name, type, derived, deletion_allowed,
|
||||
// min_available_version, min_decryption_version, min_encryption_version, latest_version,
|
||||
// exportable, allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived)
|
||||
// For single version exports, metadata also includes: export_version
|
||||
ObservationTypeTransitKeyExport = "transit/key/export"
|
||||
|
||||
// ObservationTypeTransitKeyExportBYOK is emitted when a key is exported using BYOK (Bring Your Own Key).
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived), export_version (if specified),
|
||||
// destination_key
|
||||
ObservationTypeTransitKeyExportBYOK = "transit/key/export/byok"
|
||||
|
||||
// ObservationTypeTransitKeyBackup is emitted when a key is backed up.
|
||||
// Metadata: key_name
|
||||
ObservationTypeTransitKeyBackup = "transit/key/backup"
|
||||
|
||||
// ObservationTypeTransitKeyRestore is emitted when a key is restored from backup.
|
||||
// Metadata: key_name, force
|
||||
ObservationTypeTransitKeyRestore = "transit/key/restore"
|
||||
|
||||
// ObservationTypeTransitKeyTrim is emitted when old key versions are trimmed.
|
||||
// Metadata: key_name, type, derived, deletion_allowed, min_available_version,
|
||||
// min_decryption_version, min_encryption_version, latest_version, exportable,
|
||||
// allow_plaintext_backup, auto_rotate_period, imported_key, kdf (if derived),
|
||||
// kdf_mode (if derived), convergent_encryption (if derived)
|
||||
ObservationTypeTransitKeyTrim = "transit/key/trim"
|
||||
)
|
||||
|
|
@ -37,11 +37,15 @@ func (b *backend) pathBackup() *framework.Path {
|
|||
}
|
||||
|
||||
func (b *backend) pathBackupRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
backup, err := b.lm.BackupPolicy(ctx, req.Storage, d.Get("name").(string))
|
||||
name := d.Get("name").(string)
|
||||
backup, err := b.lm.BackupPolicy(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyBackup, map[string]interface{}{
|
||||
"key_name": name,
|
||||
})
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"backup": backup,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_BackupRestore(t *testing.T) {
|
||||
|
|
@ -46,7 +47,7 @@ func testBackupRestore(t *testing.T, keyType, feature string) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
// Create a key
|
||||
keyReq := &logical.Request{
|
||||
|
|
@ -255,4 +256,13 @@ func testBackupRestore(t *testing.T, keyType, feature string) {
|
|||
|
||||
// Ensure that the restored key is functional
|
||||
validationFunc("test1")
|
||||
|
||||
backupObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyBackup)
|
||||
require.Len(t, backupObservations, 1)
|
||||
require.Equal(t, "test", backupObservations[0].Data["key_name"])
|
||||
|
||||
restoreObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyRestore)
|
||||
require.Len(t, restoreObservations, 2)
|
||||
require.Equal(t, "test", restoreObservations[0].Data["key_name"])
|
||||
require.Equal(t, "test1", restoreObservations[1].Data["key_name"])
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ func (b *backend) pathPolicyBYOKExportRead(ctx context.Context, req *logical.Req
|
|||
}
|
||||
|
||||
retKeys := map[string]string{}
|
||||
var exportVersion *int
|
||||
switch version {
|
||||
case "":
|
||||
for k, v := range srcP.Keys {
|
||||
|
|
@ -139,8 +140,16 @@ func (b *backend) pathPolicyBYOKExportRead(ctx context.Context, req *logical.Req
|
|||
}
|
||||
|
||||
retKeys[strconv.Itoa(versionValue)] = exportKey
|
||||
exportVersion = &versionValue
|
||||
}
|
||||
|
||||
metadata := b.keyPolicyObservationMetadata(srcP)
|
||||
if exportVersion != nil {
|
||||
metadata["export_version"] = *exportVersion
|
||||
}
|
||||
metadata["destination_key"] = dstP.Name
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExportBYOK, metadata)
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": srcP.Name,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_BYOKExportImport(t *testing.T) {
|
||||
|
|
@ -36,7 +37,7 @@ func testBYOKExportImport(t *testing.T, keyType, feature string) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
// Create a key
|
||||
keyReq := &logical.Request{
|
||||
|
|
@ -92,6 +93,11 @@ func testBYOKExportImport(t *testing.T, keyType, feature string) {
|
|||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyExportBYOK)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, "test-source", obs.Data["key_name"])
|
||||
require.Equal(t, "wrapper", obs.Data["destination_key"])
|
||||
keys := resp.Data["keys"].(map[string]string)
|
||||
|
||||
// Import the key to a new name.
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
|
|||
retKeys[k] = exportKey
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExport, b.keyPolicyObservationMetadata(p))
|
||||
default:
|
||||
var versionValue int
|
||||
if version == "latest" {
|
||||
|
|
@ -155,6 +156,9 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
|
|||
}
|
||||
|
||||
retKeys[strconv.Itoa(versionValue)] = exportKey
|
||||
metadata := b.keyPolicyObservationMetadata(p)
|
||||
metadata["export_version"] = versionValue
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyExport, metadata)
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) {
|
|||
|
||||
func verifyExportsCorrectVersion(t *testing.T, exportType, keyType, parameterSet, ecKeyType string) {
|
||||
t.Run(keyType+":"+ecKeyType, func(t *testing.T) {
|
||||
b, storage := createBackendWithSysView(t)
|
||||
b, storage, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
// First create a key, v1
|
||||
req := &logical.Request{
|
||||
|
|
@ -161,6 +161,10 @@ func verifyExportsCorrectVersion(t *testing.T, exportType, keyType, parameterSet
|
|||
t.Fatalf("expected version %q, received version %q", strconv.Itoa(expectedVersion), k)
|
||||
}
|
||||
}
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyExport)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, obs.Data["key_name"], "foo")
|
||||
require.Equal(t, obs.Data["export_version"], expectedVersion)
|
||||
}
|
||||
|
||||
verifyVersion("v1", 1)
|
||||
|
|
|
|||
|
|
@ -256,6 +256,15 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *
|
|||
return nil, err
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyImport, map[string]interface{}{
|
||||
"key_name": name,
|
||||
"type": polReq.KeyType,
|
||||
"derived": polReq.Derived,
|
||||
"exportable": polReq.Exportable,
|
||||
"allow_plaintext_backup": polReq.AllowPlaintextBackup,
|
||||
"auto_rotate_period": int64(autoRotatePeriod.Seconds()),
|
||||
})
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
@ -297,14 +306,16 @@ func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Reque
|
|||
return resp, err
|
||||
}
|
||||
|
||||
var versionToUpdate *int
|
||||
// Get param version if set else import a new version.
|
||||
if version, ok := d.GetOk("version"); ok {
|
||||
versionToUpdate := version.(int)
|
||||
versionValue := version.(int)
|
||||
versionToUpdate = &versionValue
|
||||
|
||||
// Check if given version can be updated given input
|
||||
err = p.KeyVersionCanBeUpdated(versionToUpdate, isCiphertextSet)
|
||||
err = p.KeyVersionCanBeUpdated(*versionToUpdate, isCiphertextSet)
|
||||
if err == nil {
|
||||
err = p.ImportPrivateKeyForVersion(ctx, req.Storage, versionToUpdate, key)
|
||||
err = p.ImportPrivateKeyForVersion(ctx, req.Storage, *versionToUpdate, key)
|
||||
}
|
||||
} else {
|
||||
err = p.ImportPublicOrPrivate(ctx, req.Storage, key, isCiphertextSet, b.GetRandomReader())
|
||||
|
|
@ -314,6 +325,12 @@ func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Reque
|
|||
return nil, err
|
||||
}
|
||||
|
||||
metadata := b.keyPolicyObservationMetadata(p)
|
||||
if versionToUpdate != nil {
|
||||
metadata["import_version"] = *versionToUpdate
|
||||
}
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyImport, metadata)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/helper/cryptoutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tink-crypto/tink-go/v2/kwp/subtle"
|
||||
)
|
||||
|
||||
|
|
@ -92,7 +93,7 @@ func getKey(t *testing.T, keyType string) interface{} {
|
|||
|
||||
func TestTransit_ImportNSSEd25519Key(t *testing.T) {
|
||||
generateKeys(t)
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
wrappingKey, err := b.getWrappingKey(context.Background(), s)
|
||||
if err != nil || wrappingKey == nil {
|
||||
|
|
@ -121,11 +122,16 @@ func TestTransit_ImportNSSEd25519Key(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to import NSS-formatted Ed25519 key: %v", err)
|
||||
}
|
||||
|
||||
// Verify observation was recorded
|
||||
importObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyImport)
|
||||
require.Len(t, importObservations, 1)
|
||||
require.Equal(t, "nss-ed25519", importObservations[0].Data["key_name"])
|
||||
}
|
||||
|
||||
func TestTransit_ImportRSAPSS(t *testing.T) {
|
||||
generateKeys(t)
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
wrappingKey, err := b.getWrappingKey(context.Background(), s)
|
||||
if err != nil || wrappingKey == nil {
|
||||
|
|
@ -154,12 +160,21 @@ func TestTransit_ImportRSAPSS(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to import RSA-PSS private key: %v", err)
|
||||
}
|
||||
|
||||
importObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyImport)
|
||||
require.Len(t, importObservations, 1)
|
||||
require.Equal(t, "rsa-pss", importObservations[0].Data["key_name"])
|
||||
}
|
||||
|
||||
func TestTransit_Import(t *testing.T) {
|
||||
generateKeys(t)
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
checkImportObservation := func(t *testing.T, keyName string) {
|
||||
t.Helper()
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, keyName, obs.Data["key_name"])
|
||||
}
|
||||
t.Run(
|
||||
"import into a key fails before wrapping key is read",
|
||||
func(t *testing.T) {
|
||||
|
|
@ -259,6 +274,7 @@ func TestTransit_Import(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to import valid key: %s", err)
|
||||
}
|
||||
checkImportObservation(t, keyID)
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -380,6 +396,7 @@ func TestTransit_Import(t *testing.T) {
|
|||
t.Fatalf("failed to import key: %s", err)
|
||||
}
|
||||
|
||||
checkImportObservation(t, keyID)
|
||||
// Rotate key
|
||||
req = &logical.Request{
|
||||
Storage: s,
|
||||
|
|
@ -390,6 +407,10 @@ func TestTransit_Import(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to rotate key: %s", err)
|
||||
}
|
||||
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRotateSuccess)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, obs.Data["key_name"], keyID)
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -418,6 +439,8 @@ func TestTransit_Import(t *testing.T) {
|
|||
t.Fatalf("failed to import key: %s", err)
|
||||
}
|
||||
|
||||
checkImportObservation(t, keyID)
|
||||
|
||||
// Rotate key
|
||||
req = &logical.Request{
|
||||
Storage: s,
|
||||
|
|
@ -461,6 +484,7 @@ func TestTransit_Import(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to import ed25519 key: %v", err)
|
||||
}
|
||||
checkImportObservation(t, keyID)
|
||||
})
|
||||
|
||||
t.Run(
|
||||
|
|
@ -493,12 +517,13 @@ func TestTransit_Import(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to import public key: %s", err)
|
||||
}
|
||||
checkImportObservation(t, keyID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransit_ImportVersion(t *testing.T) {
|
||||
generateKeys(t)
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
t.Run(
|
||||
"import into a key version fails before wrapping key is read",
|
||||
|
|
@ -686,13 +711,19 @@ func TestTransit_ImportVersion(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to update key: %s", err)
|
||||
}
|
||||
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, keyID, obs.Data["key_name"])
|
||||
require.Equal(t, keyType, obs.Data["type"])
|
||||
require.NotContains(t, obs.Data, "import_version")
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestTransit_ImportVersionWithPublicKeys(t *testing.T) {
|
||||
generateKeys(t)
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
// Retrieve public wrapping key
|
||||
wrappingKey, err := b.getWrappingKey(context.Background(), s)
|
||||
|
|
@ -931,6 +962,11 @@ func TestTransit_ImportVersionWithPublicKeys(t *testing.T) {
|
|||
t.Fatalf("failed to import private key: %s", err)
|
||||
}
|
||||
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyImport)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, keyID, obs.Data["key_name"])
|
||||
require.Equal(t, 1, obs.Data["import_version"])
|
||||
|
||||
// We should still have two keys on export
|
||||
req = &logical.Request{
|
||||
Storage: s,
|
||||
|
|
|
|||
|
|
@ -348,6 +348,12 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
|
|||
if !upserted {
|
||||
resp.AddWarning(fmt.Sprintf("key %s already existed", name))
|
||||
}
|
||||
|
||||
metadata := b.keyPolicyObservationMetadata(p)
|
||||
if polReq.ManagedKeyUUID != "" {
|
||||
metadata["managed_key_id"] = polReq.ManagedKeyUUID
|
||||
}
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyWrite, metadata)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
@ -387,9 +393,51 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f
|
|||
}
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRead, b.keyPolicyObservationMetadata(p))
|
||||
return b.formatKeyPolicy(p, context)
|
||||
}
|
||||
|
||||
func (b *backend) keyPolicyObservationMetadata(p *keysutil.Policy) map[string]interface{} {
|
||||
metadata := map[string]interface{}{
|
||||
"key_name": p.Name,
|
||||
"type": p.Type.String(),
|
||||
"derived": p.Derived,
|
||||
"deletion_allowed": p.DeletionAllowed,
|
||||
"min_available_version": p.MinAvailableVersion,
|
||||
"min_decryption_version": p.MinDecryptionVersion,
|
||||
"min_encryption_version": p.MinEncryptionVersion,
|
||||
"latest_version": p.LatestVersion,
|
||||
"exportable": p.Exportable,
|
||||
"allow_plaintext_backup": p.AllowPlaintextBackup,
|
||||
"auto_rotate_period": int64(p.AutoRotatePeriod.Seconds()),
|
||||
"imported_key": p.Imported,
|
||||
}
|
||||
|
||||
if p.Derived {
|
||||
switch p.KDF {
|
||||
case keysutil.Kdf_hmac_sha256_counter:
|
||||
metadata["kdf"] = "hmac-sha256-counter"
|
||||
metadata["kdf_mode"] = "hmac-sha256-counter"
|
||||
case keysutil.Kdf_hkdf_sha256:
|
||||
metadata["kdf"] = "hkdf_sha256"
|
||||
}
|
||||
metadata["convergent_encryption"] = p.ConvergentEncryption
|
||||
if p.ConvergentEncryption {
|
||||
metadata["convergent_encryption_version"] = p.ConvergentVersion
|
||||
}
|
||||
}
|
||||
|
||||
if p.ParameterSet != "" {
|
||||
metadata["parameter_set"] = p.ParameterSet
|
||||
}
|
||||
|
||||
if p.Type == keysutil.KeyType_HYBRID {
|
||||
metadata["hybrid_key_type_pqc"] = p.HybridConfig.PQCKeyType.String()
|
||||
metadata["hybrid_key_type_ec"] = p.HybridConfig.ECKeyType.String()
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func (b *backend) formatKeyPolicy(p *keysutil.Policy, context []byte) (*logical.Response, error) {
|
||||
// Return the response
|
||||
resp := &logical.Response{
|
||||
|
|
@ -548,6 +596,9 @@ func (b *backend) pathPolicyDelete(ctx context.Context, req *logical.Request, d
|
|||
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyDelete, map[string]interface{}{
|
||||
"key_name": name,
|
||||
})
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,16 @@ func (b *backend) pathRestoreUpdate(ctx context.Context, req *logical.Request, d
|
|||
return nil, ErrInvalidKeyName
|
||||
}
|
||||
|
||||
return nil, b.lm.RestorePolicy(ctx, req.Storage, keyName, backupB64, force)
|
||||
restoredKeyName, err := b.lm.RestorePolicy(ctx, req.Storage, keyName, backupB64, force)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRestore, map[string]interface{}{
|
||||
"key_name": restoredKeyName,
|
||||
"force": force,
|
||||
})
|
||||
// ignore-nil-nil-function-check
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/helper/testhelpers"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_Restore(t *testing.T) {
|
||||
|
|
@ -25,7 +26,7 @@ func TestTransit_Restore(t *testing.T) {
|
|||
// as if the key already existed
|
||||
|
||||
keyType := "aes256-gcm96"
|
||||
b, s := createBackendWithStorage(t)
|
||||
b, s, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
keyName := testhelpers.RandomWithPrefix("my-key")
|
||||
|
||||
// Create a key
|
||||
|
|
@ -227,6 +228,16 @@ func TestTransit_Restore(t *testing.T) {
|
|||
readKeyName = tc.RestoreName
|
||||
}
|
||||
|
||||
if tc.ExpectedErr == nil {
|
||||
obs := obsRecorder.LastObservationOfType(ObservationTypeTransitKeyRestore)
|
||||
require.NotNil(t, obs)
|
||||
require.Equal(t, obs.Data["key_name"], readKeyName)
|
||||
force := false
|
||||
if tc.Force != nil {
|
||||
force = *tc.Force
|
||||
}
|
||||
require.Equal(t, obs.Data["force"], force)
|
||||
}
|
||||
// read the key and make sure it's there
|
||||
readReq := &logical.Request{
|
||||
Path: "keys/" + readKeyName,
|
||||
|
|
|
|||
|
|
@ -68,9 +68,8 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
|
|||
p.Lock(true)
|
||||
}
|
||||
defer p.Unlock()
|
||||
|
||||
var keyId string
|
||||
if p.Type == keysutil.KeyType_MANAGED_KEY {
|
||||
var keyId string
|
||||
keyId, err = GetManagedKeyUUID(ctx, b, managedKeyName, managedKeyId)
|
||||
if err != nil {
|
||||
b.Logger().Error("failed to rotate key", "name", name, "error", err.Error())
|
||||
|
|
@ -82,17 +81,26 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
|
|||
err = p.Rotate(ctx, req.Storage, b.GetRandomReader())
|
||||
}
|
||||
|
||||
keyMetadata := b.keyPolicyObservationMetadata(p)
|
||||
if p.Type == keysutil.KeyType_MANAGED_KEY && keyId != "" {
|
||||
keyMetadata["managed_key_id"] = keyId
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
b.Logger().Error("failed to rotate key on user request", "name", name, "error", err.Error())
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateFail, keyMetadata)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := b.formatKeyPolicy(p, nil)
|
||||
if err != nil {
|
||||
b.Logger().Error("failed to rotate key on user request", "name", name, "error", err.Error())
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateFail, keyMetadata)
|
||||
} else {
|
||||
b.Logger().Info("successfully rotated key on user request", "name", name)
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyRotateSuccess, keyMetadata)
|
||||
}
|
||||
|
||||
// formatKeyPolicy returns a response even on error so be sure to return both.
|
||||
return resp, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
b.TryRecordObservationWithRequest(ctx, req, ObservationTypeTransitKeyTrim, b.keyPolicyObservationMetadata(p))
|
||||
|
||||
return b.formatKeyPolicy(p, nil)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ import (
|
|||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_Trim(t *testing.T) {
|
||||
b, storage := createBackendWithSysView(t)
|
||||
b, storage, obsRecorder := createBackendWithObservationRecorder(t)
|
||||
|
||||
doReq := func(t *testing.T, req *logical.Request) *logical.Response {
|
||||
t.Helper()
|
||||
|
|
@ -270,4 +271,9 @@ func TestTransit_Trim(t *testing.T) {
|
|||
if len(archive.Keys) != 4 {
|
||||
t.Fatalf("bad: len of archived keys; expected: 4, actual: %d", len(archive.Keys))
|
||||
}
|
||||
|
||||
trimObservations := obsRecorder.ObservationsByType(ObservationTypeTransitKeyTrim)
|
||||
require.Len(t, trimObservations, 2)
|
||||
require.Equal(t, trimObservations[0].Data["key_name"], "aes")
|
||||
require.Equal(t, trimObservations[1].Data["key_name"], "aes")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,21 +150,21 @@ func (lm *LockManager) InitCache(cacheSize int) error {
|
|||
|
||||
// RestorePolicy acquires an exclusive lock on the policy name and restores the
|
||||
// given policy along with the archive.
|
||||
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
|
||||
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) (string, error) {
|
||||
backupBytes, err := base64.StdEncoding.DecodeString(backup)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
var keyData KeyData
|
||||
err = jsonutil.DecodeJSON(backupBytes, &keyData)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the policy exists in the backup data
|
||||
if keyData.Policy == nil {
|
||||
return errors.New("backup data does not contain a valid policy")
|
||||
return "", errors.New("backup data does not contain a valid policy")
|
||||
}
|
||||
|
||||
// Set a different name if desired
|
||||
|
|
@ -188,7 +188,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
|
|||
if lm.useCache {
|
||||
pRaw, ok = lm.cache.Load(name)
|
||||
if ok && !force {
|
||||
return fmt.Errorf("key %q already exists", name)
|
||||
return "", fmt.Errorf("key %q already exists", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -207,10 +207,10 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
|
|||
if pRaw == nil {
|
||||
p, err = lm.getPolicyFromStorage(ctx, storage, name)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
if p != nil && !force {
|
||||
return fmt.Errorf("key %q already exists", name)
|
||||
return "", fmt.Errorf("key %q already exists", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
|
|||
if keyData.ArchivedKeys != nil {
|
||||
err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
|
||||
return "", errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
|
|||
// Restore the policy. This will also attempt to adjust the archive.
|
||||
err = keyData.Policy.Persist(ctx, storage)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
|
||||
return "", errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
|
||||
}
|
||||
|
||||
keyData.Policy.l = new(sync.RWMutex)
|
||||
|
|
@ -252,7 +252,7 @@ func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storag
|
|||
if lm.useCache {
|
||||
lm.cache.Store(name, keyData.Policy)
|
||||
}
|
||||
return nil
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ func TestRestorePolicy_NilPolicy(t *testing.T) {
|
|||
// Create backup data without "policy" field (causes nil Policy)
|
||||
invalidBackup := base64.StdEncoding.EncodeToString([]byte(`{"archived_keys": null}`))
|
||||
|
||||
err = lm.RestorePolicy(ctx, storage, "test-key", invalidBackup, false)
|
||||
_, err = lm.RestorePolicy(ctx, storage, "test-key", invalidBackup, false)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "backup data does not contain a valid policy")
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue